diff --git a/src/scala/microsoft-spark-3-3/pom.xml b/src/scala/microsoft-spark-3-3/pom.xml
new file mode 100644
index 000000000..ddfbab676
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/pom.xml
@@ -0,0 +1,85 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>com.microsoft.scala</groupId>
+    <artifactId>microsoft-spark</artifactId>
+    <version>${microsoft-spark.version}</version>
+  </parent>
+  <artifactId>microsoft-spark-3-3_2.12</artifactId>
+  <inceptionYear>2019</inceptionYear>
+  <properties>
+    <encoding>UTF-8</encoding>
+    <scala.version>2.12.10</scala.version>
+    <scala.binary.version>2.12</scala.binary.version>
+    <spark.version>3.3.0</spark.version>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+      <version>${scala.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <version>4.13.1</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.specs</groupId>
+      <artifactId>specs</artifactId>
+      <version>1.2.5</version>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <sourceDirectory>src/main/scala</sourceDirectory>
+    <testSourceDirectory>src/test/scala</testSourceDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.scala-tools</groupId>
+        <artifactId>maven-scala-plugin</artifactId>
+        <version>2.15.2</version>
+        <executions>
+          <execution>
+            <goals>
+              <goal>compile</goal>
+              <goal>testCompile</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <scalaVersion>${scala.version}</scalaVersion>
+          <args>
+            <arg>-target:jvm-1.8</arg>
+            <arg>-deprecation</arg>
+            <arg>-feature</arg>
+          </args>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
new file mode 100644
index 000000000..aea355dfa
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.io.DataOutputStream
+
+import org.apache.spark.internal.Logging
+
+import scala.collection.mutable.Queue
+
+/**
+ * CallbackClient is used to communicate with the Dotnet CallbackServer.
+ * The client manages and maintains a pool of open CallbackConnections.
+ * Each callback request is delegated to a new or a currently unused
+ * CallbackConnection.
+ * @param serDe SerDe instance used to serialize the callback payloads
+ * @param address The address of the Dotnet CallbackServer
+ * @param port The port of the Dotnet CallbackServer
+ */
+class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging {
+ private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]()
+
+ private[this] var isShutdown: Boolean = false
+
+ final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit =
+ getOrCreateConnection() match {
+ case Some(connection) =>
+ try {
+ connection.send(callbackId, writeBody)
+ addConnection(connection)
+ } catch {
+ case e: Exception =>
+ logError(s"Error calling callback [callback id = $callbackId].", e)
+ connection.close()
+ throw e
+ }
+ case None => throw new Exception("Unable to get or create connection.")
+ }
+
+ private def getOrCreateConnection(): Option[CallbackConnection] = synchronized {
+ if (isShutdown) {
+ logInfo("Cannot get or create connection while client is shutdown.")
+ return None
+ }
+
+ if (connectionPool.nonEmpty) {
+ return Some(connectionPool.dequeue())
+ }
+
+ Some(new CallbackConnection(serDe, address, port))
+ }
+
+ private def addConnection(connection: CallbackConnection): Unit = synchronized {
+ assert(connection != null)
+ connectionPool.enqueue(connection)
+ }
+
+ def shutdown(): Unit = synchronized {
+ if (isShutdown) {
+ logInfo("Shutdown called, but already shutdown.")
+ return
+ }
+
+ logInfo("Shutting down.")
+ connectionPool.foreach(_.close())
+ connectionPool.clear()
+ isShutdown = true
+ }
+}
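
A minimal usage sketch of the client above (the tracker, host, port, and callback id are illustrative assumptions, and a .NET CallbackServer must already be listening for send to succeed): the caller passes a callback id registered on the .NET side together with a writeBody function that serializes the payload through the shared SerDe.

    val tracker = new JVMObjectTracker
    val client = new CallbackClient(new SerDe(tracker), "127.0.0.1", 55555)

    // Invoke callback #7 on the .NET side with a string and a long as the payload.
    client.send(7, (dos, serDe) => {
      serDe.writeString(dos, "status")
      serDe.writeLong(dos, 42L)
    })

    client.shutdown()
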
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala
new file mode 100644
index 000000000..604cf029b
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.io.{ByteArrayOutputStream, Closeable, DataInputStream, DataOutputStream}
+import java.net.Socket
+
+import org.apache.spark.internal.Logging
+
+/**
+ * CallbackConnection is used to process the callback communication
+ * between the JVM and Dotnet. It uses a TCP socket to communicate with
+ * the Dotnet CallbackServer and the socket is expected to be reused.
+ * @param address The address of the Dotnet CallbackServer
+ * @param port The port of the Dotnet CallbackServer
+ */
+class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging {
+ private[this] val socket: Socket = new Socket(address, port)
+ private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream)
+ private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream)
+
+ def send(
+ callbackId: Int,
+ writeBody: (DataOutputStream, SerDe) => Unit): Unit = {
+ logInfo(s"Calling callback [callback id = $callbackId] ...")
+
+ try {
+ serDe.writeInt(outputStream, CallbackFlags.CALLBACK)
+ serDe.writeInt(outputStream, callbackId)
+
+ val byteArrayOutputStream = new ByteArrayOutputStream()
+ writeBody(new DataOutputStream(byteArrayOutputStream), serDe)
+ serDe.writeInt(outputStream, byteArrayOutputStream.size)
+ byteArrayOutputStream.writeTo(outputStream)
+ } catch {
+ case e: Exception => {
+ throw new Exception("Error writing to stream.", e)
+ }
+ }
+
+ logInfo(s"Signaling END_OF_STREAM.")
+ try {
+ serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM)
+ outputStream.flush()
+
+ val endOfStreamResponse = readFlag(inputStream)
+ endOfStreamResponse match {
+ case CallbackFlags.END_OF_STREAM =>
+ logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.")
+ case _ => {
+ throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " +
+ s"Received: $endOfStreamResponse")
+ }
+ }
+ } catch {
+ case e: Exception => {
+ throw new Exception("Error while verifying end of stream.", e)
+ }
+ }
+ }
+
+ def close(): Unit = {
+ try {
+ serDe.writeInt(outputStream, CallbackFlags.CLOSE)
+ outputStream.flush()
+ } catch {
+ case e: Exception => logInfo("Unable to send close to .NET callback server.", e)
+ }
+
+ close(socket)
+ close(outputStream)
+ close(inputStream)
+ }
+
+ private def close(s: Socket): Unit = {
+ try {
+ assert(s != null)
+ s.close()
+ } catch {
+ case e: Exception => logInfo("Unable to close socket.", e)
+ }
+ }
+
+ private def close(c: Closeable): Unit = {
+ try {
+ assert(c != null)
+ c.close()
+ } catch {
+ case e: Exception => logInfo("Unable to close closeable.", e)
+ }
+ }
+
+ private def readFlag(inputStream: DataInputStream): Int = {
+ val callbackFlag = serDe.readInt(inputStream)
+ if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) {
+ val exceptionMessage = serDe.readString(inputStream)
+ throw new DotnetException(exceptionMessage)
+ }
+ callbackFlag
+ }
+
+ private object CallbackFlags {
+ val CLOSE: Int = -1
+ val CALLBACK: Int = -2
+ val DOTNET_EXCEPTION_THROWN: Int = -3
+ val END_OF_STREAM: Int = -4
+ }
+}
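
The wire format implied by send and readFlag above is: the CALLBACK flag (-2), the callback id, a 4-byte body length, the body bytes, then END_OF_STREAM (-4); the server replies with END_OF_STREAM on success, or DOTNET_EXCEPTION_THROWN (-3) followed by a message. Below is a test-only sketch of a server that accepts a single callback over this protocol (this helper is an assumption, not part of the repository; the flag values are copied from CallbackFlags).

    import java.io.{DataInputStream, DataOutputStream}
    import java.net.ServerSocket

    def serveOneCallback(port: Int): Unit = {
      val server = new ServerSocket(port)
      val socket = server.accept()
      val in = new DataInputStream(socket.getInputStream)
      val out = new DataOutputStream(socket.getOutputStream)

      assert(in.readInt() == -2)            // CallbackFlags.CALLBACK
      val callbackId = in.readInt()         // which .NET callback to run
      val body = new Array[Byte](in.readInt())
      in.readFully(body)                    // length-prefixed payload written by writeBody
      assert(in.readInt() == -4)            // CallbackFlags.END_OF_STREAM from the JVM
      println(s"received callback $callbackId with ${body.length} body bytes")

      out.writeInt(-4)                      // acknowledge with END_OF_STREAM
      out.flush()
      socket.close()
      server.close()
    }
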
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
new file mode 100644
index 000000000..c6f528aee
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.net.InetSocketAddress
+import java.util.concurrent.TimeUnit
+import io.netty.bootstrap.ServerBootstrap
+import io.netty.channel.nio.NioEventLoopGroup
+import io.netty.channel.socket.SocketChannel
+import io.netty.channel.socket.nio.NioServerSocketChannel
+import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder
+import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_NUM_BACKEND_THREADS
+import org.apache.spark.{SparkConf, SparkEnv}
+
+/**
+ * Netty server that invokes JVM calls based upon receiving messages from .NET.
+ * The implementation mirrors the RBackend.
+ *
+ */
+class DotnetBackend extends Logging {
+ self => // for accessing the "this" reference in the inner class (ChannelInitializer)
+ private[this] var channelFuture: ChannelFuture = _
+ private[this] var bootstrap: ServerBootstrap = _
+ private[this] var bossGroup: EventLoopGroup = _
+ private[this] val objectTracker = new JVMObjectTracker
+
+ @volatile
+ private[dotnet] var callbackClient: Option[CallbackClient] = None
+
+ def init(portNumber: Int): Int = {
+ val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+ val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS)
+ logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.")
+ bossGroup = new NioEventLoopGroup(numBackendThreads)
+ val workerGroup = bossGroup
+
+ bootstrap = new ServerBootstrap()
+ .group(bossGroup, workerGroup)
+ .channel(classOf[NioServerSocketChannel])
+
+ bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
+ def initChannel(ch: SocketChannel): Unit = {
+ ch.pipeline()
+ .addLast("encoder", new ByteArrayEncoder())
+ .addLast(
+ "frameDecoder",
+ // maxFrameLength = 2G
+ // lengthFieldOffset = 0
+ // lengthFieldLength = 4
+ // lengthAdjustment = 0
+ // initialBytesToStrip = 4, i.e. strip out the length field itself
+ new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
+ .addLast("decoder", new ByteArrayDecoder())
+ .addLast("handler", new DotnetBackendHandler(self, objectTracker))
+ }
+ })
+
+ channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber))
+ channelFuture.syncUninterruptibly()
+ channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
+ }
+
+ private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized {
+ callbackClient = callbackClient match {
+ case Some(_) => throw new Exception("Callback client already set.")
+ case None =>
+ logInfo(s"Connecting to a callback server at $address:$port")
+ Some(new CallbackClient(new SerDe(objectTracker), address, port))
+ }
+ }
+
+ private[dotnet] def shutdownCallbackClient(): Unit = synchronized {
+ callbackClient match {
+ case Some(client) => client.shutdown()
+ case None => logInfo("Callback server has already been shutdown.")
+ }
+ callbackClient = None
+ }
+
+ def run(): Unit = {
+ channelFuture.channel.closeFuture().syncUninterruptibly()
+ }
+
+ def close(): Unit = {
+ if (channelFuture != null) {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS)
+ channelFuture = null
+ }
+ if (bootstrap != null && bootstrap.config().group() != null) {
+ bootstrap.config().group().shutdownGracefully()
+ }
+ if (bootstrap != null && bootstrap.config().childGroup() != null) {
+ bootstrap.config().childGroup().shutdownGracefully()
+ }
+ bootstrap = null
+
+ objectTracker.clear()
+
+ // Send close to .NET callback server.
+ shutdownCallbackClient()
+
+ // Shutdown the thread pool whose executors could still be running.
+ ThreadPool.shutdown()
+ }
+}
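
A lifecycle sketch for the backend (assumptions: the caller owns the thread, and port 0 is requested so the OS assigns a free one, which is how DotnetRunner further down uses init):

    val backend = new DotnetBackend()
    val boundPort = backend.init(0)             // returns the actual port when 0 is requested

    val serverThread = new Thread("DotnetBackend") {
      override def run(): Unit = backend.run()  // blocks until the server channel closes
    }
    serverThread.start()

    // ... hand boundPort to the .NET process via the DOTNETBACKEND_PORT environment variable ...

    backend.close()                             // shuts down Netty, the callback client and the ThreadPool
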
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
new file mode 100644
index 000000000..fc3a5c911
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import scala.collection.mutable.HashMap
+import scala.language.existentials
+
+/**
+ * Handler for DotnetBackend.
+ * This implementation is similar to RBackendHandler.
+ */
+class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker)
+ extends SimpleChannelInboundHandler[Array[Byte]]
+ with Logging {
+
+ private[this] val serDe = new SerDe(objectsTracker)
+
+ override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
+ val reply = handleBackendRequest(msg)
+ ctx.write(reply)
+ }
+
+ override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
+ ctx.flush()
+ }
+
+ def handleBackendRequest(msg: Array[Byte]): Array[Byte] = {
+ val bis = new ByteArrayInputStream(msg)
+ val dis = new DataInputStream(bis)
+
+ val bos = new ByteArrayOutputStream()
+ val dos = new DataOutputStream(bos)
+
+ // First byte is isStatic
+ val isStatic = serDe.readBoolean(dis)
+ val processId = serDe.readInt(dis)
+ val threadId = serDe.readInt(dis)
+ val objId = serDe.readString(dis)
+ val methodName = serDe.readString(dis)
+ val numArgs = serDe.readInt(dis)
+
+ if (objId == "DotnetHandler") {
+ methodName match {
+ case "stopBackend" =>
+ serDe.writeInt(dos, 0)
+ serDe.writeType(dos, "void")
+ server.close()
+ case "rm" =>
+ try {
+ val t = serDe.readObjectType(dis)
+ assert(t == 'c')
+ val objToRemove = serDe.readString(dis)
+ objectsTracker.remove(objToRemove)
+ serDe.writeInt(dos, 0)
+ serDe.writeObject(dos, null)
+ } catch {
+ case e: Exception =>
+ logError(s"Removing $objId failed", e)
+ serDe.writeInt(dos, -1)
+ }
+ case "rmThread" =>
+ try {
+ assert(serDe.readObjectType(dis) == 'i')
+ val processId = serDe.readInt(dis)
+ assert(serDe.readObjectType(dis) == 'i')
+ val threadToDelete = serDe.readInt(dis)
+ val result = ThreadPool.tryDeleteThread(processId, threadToDelete)
+ serDe.writeInt(dos, 0)
+ serDe.writeObject(dos, result.asInstanceOf[AnyRef])
+ } catch {
+ case e: Exception =>
+ logError(s"Removing thread $threadId failed", e)
+ serDe.writeInt(dos, -1)
+ }
+ case "connectCallback" =>
+ assert(serDe.readObjectType(dis) == 'c')
+ val address = serDe.readString(dis)
+ assert(serDe.readObjectType(dis) == 'i')
+ val port = serDe.readInt(dis)
+ server.setCallbackClient(address, port)
+ serDe.writeInt(dos, 0)
+
+ // Sends reference of CallbackClient to dotnet side,
+ // so that dotnet process can send the client back to Java side
+ // when calling any API containing callback functions.
+ serDe.writeObject(dos, server.callbackClient)
+ case "closeCallback" =>
+ logInfo("Requesting to close callback client")
+ server.shutdownCallbackClient()
+ serDe.writeInt(dos, 0)
+ serDe.writeType(dos, "void")
+ case _ => dos.writeInt(-1)
+ }
+ } else {
+ ThreadPool
+ .run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
+ }
+
+ bos.toByteArray
+ }
+
+ override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
+ // Skip logging the exception message if the connection was disconnected from
+ // the .NET side so that .NET side doesn't have to explicitly close the connection via
+ // "stopBackend." Note that an exception is still thrown if the exit status is non-zero,
+ // so skipping this kind of exception message does not affect the debugging.
+ if (!cause.getMessage.contains(
+ "An existing connection was forcibly closed by the remote host")) {
+ logError("Exception caught: ", cause)
+ }
+
+ // Close the connection when an exception is raised.
+ ctx.close()
+ }
+
+ def handleMethodCall(
+ isStatic: Boolean,
+ objId: String,
+ methodName: String,
+ numArgs: Int,
+ dis: DataInputStream,
+ dos: DataOutputStream): Unit = {
+ var obj: Object = null
+ var args: Array[java.lang.Object] = null
+ var methods: Array[java.lang.reflect.Method] = null
+
+ try {
+ val cls = if (isStatic) {
+ Utils.classForName(objId)
+ } else {
+ objectsTracker.get(objId) match {
+ case None => throw new IllegalArgumentException("Object not found " + objId)
+ case Some(o) =>
+ obj = o
+ o.getClass
+ }
+ }
+
+ args = readArgs(numArgs, dis)
+ methods = cls.getMethods
+
+ val selectedMethods = methods.filter(m => m.getName == methodName)
+ if (selectedMethods.length > 0) {
+ val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args)
+
+ if (index.isEmpty) {
+ logWarning(
+ s"cannot find matching method ${cls}.$methodName. "
+ + s"Candidates are:")
+ selectedMethods.foreach { method =>
+ logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})")
+ }
+ throw new Exception(s"No matched method found for $cls.$methodName")
+ }
+
+ val ret = selectedMethods(index.get).invoke(obj, args: _*)
+
+ // Write status bit
+ serDe.writeInt(dos, 0)
+ serDe.writeObject(dos, ret.asInstanceOf[AnyRef])
+ } else if (methodName == "") {
+ // methodName should be "" for constructor
+ val ctor = cls.getConstructors.filter { x =>
+ matchMethod(numArgs, args, x.getParameterTypes)
+ }.head
+
+ val obj = ctor.newInstance(args: _*)
+
+ serDe.writeInt(dos, 0)
+ serDe.writeObject(dos, obj.asInstanceOf[AnyRef])
+ } else {
+ throw new IllegalArgumentException(
+ "invalid method " + methodName + " for object " + objId)
+ }
+ } catch {
+ case e: Throwable =>
+ val jvmObj = objectsTracker.get(objId)
+ val jvmObjName = jvmObj match {
+ case Some(jObj) => jObj.getClass.getName
+ case None => "NullObject"
+ }
+ val argsStr = args
+ .map(arg => {
+ if (arg != null) {
+ s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
+ } else {
+ "[Value: NULL]"
+ }
+ })
+ .mkString(", ")
+
+ logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
+
+ if (methods != null) {
+ logDebug(s"All methods for $jvmObjName:")
+ methods.foreach(m => logDebug(m.toString))
+ }
+
+ serDe.writeInt(dos, -1)
+ serDe.writeString(dos, Utils.exceptionString(e.getCause))
+ }
+ }
+
+ // Read a number of arguments from the data input stream
+ def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
+ (0 until numArgs).map { arg =>
+ serDe.readObject(dis)
+ }.toArray
+ }
+
+ // Checks if the arguments passed in args matches the parameter types.
+ // NOTE: Currently we do exact match. We may add type conversions later.
+ def matchMethod(
+ numArgs: Int,
+ args: Array[java.lang.Object],
+ parameterTypes: Array[Class[_]]): Boolean = {
+ if (parameterTypes.length != numArgs) {
+ return false
+ }
+
+ for (i <- 0 until numArgs) {
+ val parameterType = parameterTypes(i)
+ var parameterWrapperType = parameterType
+
+ // Convert native parameters to Object types as args is Array[Object] here
+ if (parameterType.isPrimitive) {
+ parameterWrapperType = parameterType match {
+ case java.lang.Integer.TYPE => classOf[java.lang.Integer]
+ case java.lang.Long.TYPE => classOf[java.lang.Long]
+ case java.lang.Double.TYPE => classOf[java.lang.Double]
+ case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
+ case _ => parameterType
+ }
+ }
+
+ if (!parameterWrapperType.isInstance(args(i))) {
+ // non primitive types
+ if (!parameterType.isPrimitive && args(i) != null) {
+ return false
+ }
+
+ // primitive types
+ if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) {
+ return false
+ }
+ }
+ }
+
+ true
+ }
+
+ // Find a matching method signature in an array of signatures of constructors
+ // or methods of the same name according to the passed arguments. Arguments
+ // may be converted in order to match a signature.
+ //
+ // Note that in Java reflection, constructors and normal methods are of different
+ // classes, and share no parent class that provides methods for reflection uses.
+ // There is no unified way to handle them in this function. So an array of signatures
+ // is passed in instead of an array of candidate constructors or methods.
+ //
+ // Returns an Option[Int] which is the index of the matched signature in the array.
+ def findMatchedSignature(
+ parameterTypesOfMethods: Array[Array[Class[_]]],
+ args: Array[Object]): Option[Int] = {
+ val numArgs = args.length
+
+ for (index <- parameterTypesOfMethods.indices) {
+ val parameterTypes = parameterTypesOfMethods(index)
+
+ if (parameterTypes.length == numArgs) {
+ var argMatched = true
+ var i = 0
+ while (i < numArgs && argMatched) {
+ val parameterType = parameterTypes(i)
+
+ if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) {
+ // The case that the parameter type is a Scala Seq and the argument
+ // is a Java array is considered matching. The array will be converted
+ // to a Seq later if this method is matched.
+ } else {
+ var parameterWrapperType = parameterType
+
+ // Convert native parameters to Object types as args is Array[Object] here
+ if (parameterType.isPrimitive) {
+ parameterWrapperType = parameterType match {
+ case java.lang.Integer.TYPE => classOf[java.lang.Integer]
+ case java.lang.Long.TYPE => classOf[java.lang.Long]
+ case java.lang.Double.TYPE => classOf[java.lang.Double]
+ case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
+ case _ => parameterType
+ }
+ }
+ if ((parameterType.isPrimitive || args(i) != null) &&
+ !parameterWrapperType.isInstance(args(i))) {
+ argMatched = false
+ }
+ }
+
+ i = i + 1
+ }
+
+ if (argMatched) {
+ // For now, we return the first matching method.
+ // TODO: find best method in matching methods.
+
+ // Convert args if needed
+ val parameterTypes = parameterTypesOfMethods(index)
+
+ for (i <- 0 until numArgs) {
+ if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) {
+ // Convert a Java array to scala Seq
+ args(i) = args(i).asInstanceOf[Array[_]].toSeq
+ }
+ }
+
+ return Some(index)
+ }
+ }
+ }
+ None
+ }
+
+ def logError(id: String, e: Exception): Unit = {}
+}
+
+
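
For reference, the request payload that handleBackendRequest decodes (after Netty strips the 4-byte length prefix) is: isStatic, processId, threadId, objId, methodName, numArgs, followed by numArgs typed values. Here is a sketch that builds such a payload by hand and feeds it to the handler; the target class and method are arbitrary examples chosen for illustration.

    import java.io.{ByteArrayOutputStream, DataOutputStream}

    val tracker = new JVMObjectTracker
    val serDe = new SerDe(tracker)
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)

    dos.writeBoolean(true)                      // isStatic: call a static method
    serDe.writeInt(dos, 1)                      // processId
    serDe.writeInt(dos, 1)                      // threadId
    serDe.writeString(dos, "java.lang.System")  // objId is a class name for static calls
    serDe.writeString(dos, "currentTimeMillis") // methodName
    serDe.writeInt(dos, 0)                      // numArgs; no typed arguments follow

    val handler = new DotnetBackendHandler(new DotnetBackend(), tracker)
    val reply = handler.handleBackendRequest(bos.toByteArray)
    // reply holds a status int (0 on success) followed by the return value written by SerDe.writeObject
    ThreadPool.shutdown()                       // clean up the executor created for (1, 1)
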
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala
new file mode 100644
index 000000000..c70d16b03
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala
@@ -0,0 +1,13 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+class DotnetException(message: String, cause: Throwable)
+ extends Exception(message, cause) {
+
+ def this(message: String) = this(message, null)
+}
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala
new file mode 100644
index 000000000..f5277c215
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python._
+import org.apache.spark.rdd.RDD
+
+object DotnetRDD {
+ def createPythonRDD(
+ parent: RDD[_],
+ func: PythonFunction,
+ preservePartitioning: Boolean): PythonRDD = {
+ new PythonRDD(parent, func, preservePartitioning)
+ }
+
+ def createJavaRDDFromArray(
+ sc: SparkContext,
+ arr: Array[Array[Byte]],
+ numSlices: Int): JavaRDD[Array[Byte]] = {
+ JavaRDD.fromRDD(sc.parallelize(arr, numSlices))
+ }
+
+ def toJavaRDD(rdd: RDD[_]): JavaRDD[_] = JavaRDD.fromRDD(rdd)
+}
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
new file mode 100644
index 000000000..9f556338b
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import scala.collection.JavaConverters._
+
+/** DotnetUtils object that hosts helper functions that
+ * assist with data type conversions between .NET and Scala.
+ */
+object DotnetUtils {
+
+ /** A helper function to convert scala Map to java.util.Map
+ * @param value - scala Map
+ * @return java.util.Map
+ */
+ def convertToJavaMap(value: Map[_, _]): java.util.Map[_, _] = value.asJava
+
+ /** Convert java data type to corresponding scala type
+ * @param value - java.lang.Object
+ * @return Any
+ */
+ def mapScalaToJava(value: java.lang.Object): Any = {
+ value match {
+ case i: java.lang.Integer => i.toInt
+ case d: java.lang.Double => d.toDouble
+ case f: java.lang.Float => f.toFloat
+ case b: java.lang.Boolean => b.booleanValue()
+ case l: java.lang.Long => l.toLong
+ case s: java.lang.Short => s.toShort
+ case by: java.lang.Byte => by.toByte
+ case c: java.lang.Character => c.toChar
+ case _ => value
+ }
+ }
+}
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala
new file mode 100644
index 000000000..81cfaf88b
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+
+package org.apache.spark.api.dotnet
+
+import scala.collection.mutable.HashMap
+
+/**
+ * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects.
+ */
+private[dotnet] class JVMObjectTracker {
+
+ // Multiple threads may access objMap and increase objCounter. Because the get method returns an
+ // Option, it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap.
+ private[this] val objMap = new HashMap[String, Object]
+ private[this] var objCounter: Int = 1
+
+ def getObject(id: String): Object = {
+ synchronized {
+ objMap(id)
+ }
+ }
+
+ def get(id: String): Option[Object] = {
+ synchronized {
+ objMap.get(id)
+ }
+ }
+
+ def put(obj: Object): String = {
+ synchronized {
+ val objId = objCounter.toString
+ objCounter = objCounter + 1
+ objMap.put(objId, obj)
+ objId
+ }
+ }
+
+ def remove(id: String): Option[Object] = {
+ synchronized {
+ objMap.remove(id)
+ }
+ }
+
+ def clear(): Unit = {
+ synchronized {
+ objMap.clear()
+ objCounter = 1
+ }
+ }
+}
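
A small sketch of how handles flow through the tracker (the id shown is what the counter produces for the first object):

    val tracker = new JVMObjectTracker
    val id = tracker.put(new java.util.ArrayList[String]()) // "1" for the first tracked object
    tracker.get(id)        // Some(obj): safe lookup used by DotnetBackendHandler
    tracker.getObject(id)  // unwrapped lookup used by SerDe for 'j' (jobj) values
    tracker.remove(id)     // invoked when the .NET side releases its reference ("rm")
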
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala
new file mode 100644
index 000000000..06a476f67
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.sql.api.dotnet
+
+import org.apache.spark.SparkConf
+
+/*
+ * Utils for JvmBridge.
+ */
+object JvmBridgeUtils {
+ def getKeyValuePairAsString(kvp: (String, String)): String = {
+ return kvp._1 + "=" + kvp._2
+ }
+
+ def getKeyValuePairArrayAsString(kvpArray: Array[(String, String)]): String = {
+ val sb = new StringBuilder
+
+ for (kvp <- kvpArray) {
+ sb.append(getKeyValuePairAsString(kvp))
+ sb.append(";")
+ }
+
+ sb.toString
+ }
+
+ def getSparkConfAsString(sparkConf: SparkConf): String = {
+ getKeyValuePairArrayAsString(sparkConf.getAll)
+ }
+}
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
new file mode 100644
index 000000000..a3df3788a
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
@@ -0,0 +1,387 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.io.{DataInputStream, DataOutputStream}
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, Time, Timestamp}
+
+import org.apache.spark.sql.Row
+
+import scala.collection.JavaConverters._
+
+/**
+ * Class responsible for serialization and deserialization between CLR & JVM.
+ * This implementation is mostly identical to the SerDe implementation in SparkR.
+ */
+class SerDe(val tracker: JVMObjectTracker) {
+
+ def readObjectType(dis: DataInputStream): Char = {
+ dis.readByte().toChar
+ }
+
+ def readObject(dis: DataInputStream): Object = {
+ val dataType = readObjectType(dis)
+ readTypedObject(dis, dataType)
+ }
+
+ private def readTypedObject(dis: DataInputStream, dataType: Char): Object = {
+ dataType match {
+ case 'n' => null
+ case 'i' => new java.lang.Integer(readInt(dis))
+ case 'g' => new java.lang.Long(readLong(dis))
+ case 'd' => new java.lang.Double(readDouble(dis))
+ case 'b' => new java.lang.Boolean(readBoolean(dis))
+ case 'c' => readString(dis)
+ case 'e' => readMap(dis)
+ case 'r' => readBytes(dis)
+ case 'l' => readList(dis)
+ case 'D' => readDate(dis)
+ case 't' => readTime(dis)
+ case 'j' => tracker.getObject(readString(dis))
+ case 'R' => readRowArr(dis)
+ case 'O' => readObjectArr(dis)
+ case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+ }
+ }
+
+ private def readBytes(in: DataInputStream): Array[Byte] = {
+ val len = readInt(in)
+ val out = new Array[Byte](len)
+ in.readFully(out)
+ out
+ }
+
+ def readInt(in: DataInputStream): Int = {
+ in.readInt()
+ }
+
+ private def readLong(in: DataInputStream): Long = {
+ in.readLong()
+ }
+
+ private def readDouble(in: DataInputStream): Double = {
+ in.readDouble()
+ }
+
+ private def readStringBytes(in: DataInputStream, len: Int): String = {
+ val bytes = new Array[Byte](len)
+ in.readFully(bytes)
+ val str = new String(bytes, "UTF-8")
+ str
+ }
+
+ def readString(in: DataInputStream): String = {
+ val len = in.readInt()
+ readStringBytes(in, len)
+ }
+
+ def readBoolean(in: DataInputStream): Boolean = {
+ in.readBoolean()
+ }
+
+ private def readDate(in: DataInputStream): Date = {
+ Date.valueOf(readString(in))
+ }
+
+ private def readTime(in: DataInputStream): Timestamp = {
+ val seconds = in.readDouble()
+ val sec = Math.floor(seconds).toLong
+ val t = new Timestamp(sec * 1000L)
+ t.setNanos(((seconds - sec) * 1e9).toInt)
+ t
+ }
+
+ private def readRow(in: DataInputStream): Row = {
+ val len = readInt(in)
+ Row.fromSeq((0 until len).map(_ => readObject(in)))
+ }
+
+ private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readBytes(in)).toArray
+ }
+
+ private def readIntArr(in: DataInputStream): Array[Int] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readInt(in)).toArray
+ }
+
+ private def readLongArr(in: DataInputStream): Array[Long] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readLong(in)).toArray
+ }
+
+ private def readDoubleArr(in: DataInputStream): Array[Double] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readDouble(in)).toArray
+ }
+
+ private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readDoubleArr(in)).toArray
+ }
+
+ private def readBooleanArr(in: DataInputStream): Array[Boolean] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readBoolean(in)).toArray
+ }
+
+ private def readStringArr(in: DataInputStream): Array[String] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readString(in)).toArray
+ }
+
+ private def readRowArr(in: DataInputStream): java.util.List[Row] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readRow(in)).toList.asJava
+ }
+
+ private def readObjectArr(in: DataInputStream): Seq[Any] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readObject(in))
+ }
+
+ private def readList(dis: DataInputStream): Array[_] = {
+ val arrType = readObjectType(dis)
+ arrType match {
+ case 'i' => readIntArr(dis)
+ case 'g' => readLongArr(dis)
+ case 'c' => readStringArr(dis)
+ case 'd' => readDoubleArr(dis)
+ case 'A' => readDoubleArrArr(dis)
+ case 'b' => readBooleanArr(dis)
+ case 'j' => readStringArr(dis).map(x => tracker.getObject(x))
+ case 'r' => readBytesArr(dis)
+ case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
+ }
+ }
+
+ private def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
+ val len = readInt(in)
+ if (len > 0) {
+ val keysType = readObjectType(in)
+ val keysLen = readInt(in)
+ val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
+
+ val valuesLen = readInt(in)
+ val values = (0 until valuesLen).map(_ => {
+ val valueType = readObjectType(in)
+ readTypedObject(in, valueType)
+ })
+ keys.zip(values).toMap.asJava
+ } else {
+ new java.util.HashMap[Object, Object]()
+ }
+ }
+
+ // Using the same mapping as SparkR implementation for now
+ // Methods to write out data from Java to .NET.
+ //
+ // Type mapping from Java to .NET:
+ //
+ // void -> NULL
+ // Int -> integer
+ // String -> character
+ // Boolean -> logical
+ // Float -> double
+ // Double -> double
+ // Long -> long
+ // Array[Byte] -> raw
+ // Date -> Date
+ // Time -> POSIXct
+ //
+ // Array[T] -> list()
+ // Object -> jobj
+
+ def writeType(dos: DataOutputStream, typeStr: String): Unit = {
+ typeStr match {
+ case "void" => dos.writeByte('n')
+ case "character" => dos.writeByte('c')
+ case "double" => dos.writeByte('d')
+ case "doublearray" => dos.writeByte('A')
+ case "long" => dos.writeByte('g')
+ case "integer" => dos.writeByte('i')
+ case "logical" => dos.writeByte('b')
+ case "date" => dos.writeByte('D')
+ case "time" => dos.writeByte('t')
+ case "raw" => dos.writeByte('r')
+ case "list" => dos.writeByte('l')
+ case "jobj" => dos.writeByte('j')
+ case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
+ }
+ }
+
+ def writeObject(dos: DataOutputStream, value: Object): Unit = {
+ if (value == null || value == Unit) {
+ writeType(dos, "void")
+ } else {
+ value.getClass.getName match {
+ case "java.lang.String" =>
+ writeType(dos, "character")
+ writeString(dos, value.asInstanceOf[String])
+ case "float" | "java.lang.Float" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Float].toDouble)
+ case "double" | "java.lang.Double" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Double])
+ case "long" | "java.lang.Long" =>
+ writeType(dos, "long")
+ writeLong(dos, value.asInstanceOf[Long])
+ case "int" | "java.lang.Integer" =>
+ writeType(dos, "integer")
+ writeInt(dos, value.asInstanceOf[Int])
+ case "boolean" | "java.lang.Boolean" =>
+ writeType(dos, "logical")
+ writeBoolean(dos, value.asInstanceOf[Boolean])
+ case "java.sql.Date" =>
+ writeType(dos, "date")
+ writeDate(dos, value.asInstanceOf[Date])
+ case "java.sql.Time" =>
+ writeType(dos, "time")
+ writeTime(dos, value.asInstanceOf[Time])
+ case "java.sql.Timestamp" =>
+ writeType(dos, "time")
+ writeTime(dos, value.asInstanceOf[Timestamp])
+ case "[B" =>
+ writeType(dos, "raw")
+ writeBytes(dos, value.asInstanceOf[Array[Byte]])
+ // TODO: Types not handled right now include
+ // byte, char, short, float
+
+ // Handle arrays
+ case "[Ljava.lang.String;" =>
+ writeType(dos, "list")
+ writeStringArr(dos, value.asInstanceOf[Array[String]])
+ case "[I" =>
+ writeType(dos, "list")
+ writeIntArr(dos, value.asInstanceOf[Array[Int]])
+ case "[J" =>
+ writeType(dos, "list")
+ writeLongArr(dos, value.asInstanceOf[Array[Long]])
+ case "[D" =>
+ writeType(dos, "list")
+ writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
+ case "[[D" =>
+ writeType(dos, "list")
+ writeDoubleArrArr(dos, value.asInstanceOf[Array[Array[Double]]])
+ case "[Z" =>
+ writeType(dos, "list")
+ writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
+ case "[[B" =>
+ writeType(dos, "list")
+ writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
+ case otherName =>
+ // Handle array of objects
+ if (otherName.startsWith("[L")) {
+ val objArr = value.asInstanceOf[Array[Object]]
+ writeType(dos, "list")
+ writeType(dos, "jobj")
+ dos.writeInt(objArr.length)
+ objArr.foreach(o => writeJObj(dos, o))
+ } else {
+ writeType(dos, "jobj")
+ writeJObj(dos, value)
+ }
+ }
+ }
+ }
+
+ def writeInt(out: DataOutputStream, value: Int): Unit = {
+ out.writeInt(value)
+ }
+
+ def writeLong(out: DataOutputStream, value: Long): Unit = {
+ out.writeLong(value)
+ }
+
+ private def writeDouble(out: DataOutputStream, value: Double): Unit = {
+ out.writeDouble(value)
+ }
+
+ private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = {
+ out.writeBoolean(value)
+ }
+
+ private def writeDate(out: DataOutputStream, value: Date): Unit = {
+ writeString(out, value.toString)
+ }
+
+ private def writeTime(out: DataOutputStream, value: Time): Unit = {
+ out.writeDouble(value.getTime.toDouble / 1000.0)
+ }
+
+ private def writeTime(out: DataOutputStream, value: Timestamp): Unit = {
+ out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9)
+ }
+
+ def writeString(out: DataOutputStream, value: String): Unit = {
+ val utf8 = value.getBytes(StandardCharsets.UTF_8)
+ val len = utf8.length
+ out.writeInt(len)
+ out.write(utf8, 0, len)
+ }
+
+ private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = {
+ out.writeInt(value.length)
+ out.write(value)
+ }
+
+ def writeJObj(out: DataOutputStream, value: Object): Unit = {
+ val objId = tracker.put(value)
+ writeString(out, objId)
+ }
+
+ private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = {
+ writeType(out, "integer")
+ out.writeInt(value.length)
+ value.foreach(v => out.writeInt(v))
+ }
+
+ private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = {
+ writeType(out, "long")
+ out.writeInt(value.length)
+ value.foreach(v => out.writeLong(v))
+ }
+
+ private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = {
+ writeType(out, "double")
+ out.writeInt(value.length)
+ value.foreach(v => out.writeDouble(v))
+ }
+
+ private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = {
+ writeType(out, "doublearray")
+ out.writeInt(value.length)
+ value.foreach(v => writeDoubleArr(out, v))
+ }
+
+ private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = {
+ writeType(out, "logical")
+ out.writeInt(value.length)
+ value.foreach(v => writeBoolean(out, v))
+ }
+
+ private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = {
+ writeType(out, "character")
+ out.writeInt(value.length)
+ value.foreach(v => writeString(out, v))
+ }
+
+ private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
+ writeType(out, "raw")
+ out.writeInt(value.length)
+ value.foreach(v => writeBytes(out, v))
+ }
+}
+
+private object SerializationFormats {
+ val BYTE = "byte"
+ val STRING = "string"
+ val ROW = "row"
+}
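
A round-trip sketch for the type tags above (illustrative values; writeObject picks the tag from the runtime class and readObject dispatches on it):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    val serDe = new SerDe(new JVMObjectTracker)
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)

    serDe.writeObject(dos, Integer.valueOf(42))  // 'i' + 4-byte int
    serDe.writeObject(dos, "hello")              // 'c' + length-prefixed UTF-8
    serDe.writeObject(dos, Array(1.0, 2.5))      // 'l' (list) + 'd' element tag + values

    val dis = new DataInputStream(new ByteArrayInputStream(bos.toByteArray))
    assert(serDe.readObject(dis) == 42)
    assert(serDe.readObject(dis) == "hello")
    val doubles = serDe.readObject(dis).asInstanceOf[Array[Double]]
    assert(doubles.sameElements(Array(1.0, 2.5)))
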
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
new file mode 100644
index 000000000..50551a7d9
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.util.concurrent.{ExecutorService, Executors}
+
+import scala.collection.mutable
+
+/**
+ * Pool of thread executors. There should be a 1-1 correspondence between C# threads
+ * and Java threads.
+ */
+object ThreadPool {
+
+ /**
+ * Map from (processId, threadId) to corresponding executor.
+ */
+ private val executors: mutable.HashMap[(Int, Int), ExecutorService] =
+ new mutable.HashMap[(Int, Int), ExecutorService]()
+
+ /**
+ * Run some code on a particular thread.
+ * @param processId Integer id of the process.
+ * @param threadId Integer id of the thread.
+ * @param task Function to run on the thread.
+ */
+ def run(processId: Int, threadId: Int, task: () => Unit): Unit = {
+ val executor = getOrCreateExecutor(processId, threadId)
+ val future = executor.submit(new Runnable {
+ override def run(): Unit = task()
+ })
+
+ future.get()
+ }
+
+ /**
+ * Try to delete a particular thread.
+ * @param processId Integer id of the process.
+ * @param threadId Integer id of the thread.
+ * @return True if successful, false if thread does not exist.
+ */
+ def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized {
+ executors.remove((processId, threadId)) match {
+ case Some(executorService) =>
+ executorService.shutdown()
+ true
+ case None => false
+ }
+ }
+
+ /**
+ * Shutdown any running ExecutorServices.
+ */
+ def shutdown(): Unit = synchronized {
+ executors.foreach(_._2.shutdown())
+ executors.clear()
+ }
+
+ /**
+ * Get the executor if it exists, otherwise create a new one.
+ * @param processId Integer id of the process.
+ * @param threadId Integer id of the thread.
+ * @return The new or existing executor with the given id.
+ */
+ private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized {
+ executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor)
+ }
+}
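
Usage sketch (the ids are illustrative): repeated calls with the same (processId, threadId) pair run on the same dedicated JVM thread, which is what gives the .NET side thread affinity for thread-bound JVM state.

    ThreadPool.run(1, 10, () => println(Thread.currentThread().getName))
    ThreadPool.run(1, 10, () => ())     // reuses the same single-thread executor
    ThreadPool.tryDeleteThread(1, 10)   // true: the executor existed and is now shut down
    ThreadPool.shutdown()               // called by DotnetBackend.close() on teardown
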
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala
new file mode 100644
index 000000000..4551a70bd
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.deploy.dotnet
+
+import org.apache.spark.SparkException
+
+/**
+ * This exception type describes an exception thrown by a .NET user application.
+ *
+ * @param exitCode Exit code returned by the .NET application.
+ * @param dotNetStackTrace Stacktrace extracted from .NET application logs.
+ */
+private[spark] class DotNetUserAppException(exitCode: Int, dotNetStackTrace: Option[String])
+ extends SparkException(
+ dotNetStackTrace match {
+ case None => s"User application exited with $exitCode"
+ case Some(e) => s"User application exited with $exitCode and .NET exception: $e"
+ })
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala
new file mode 100644
index 000000000..894f21c21
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.deploy.dotnet
+
+import java.io.File
+import java.net.URI
+import java.nio.file.attribute.PosixFilePermissions
+import java.nio.file.{FileSystems, Files, Paths}
+import java.util.Locale
+import java.util.concurrent.{Semaphore, TimeUnit}
+
+import org.apache.commons.io.FilenameUtils
+import org.apache.commons.io.output.TeeOutputStream
+import org.apache.hadoop.fs.Path
+import org.apache.spark
+import org.apache.spark.api.dotnet.DotnetBackend
+import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.dotnet.Dotnet.{
+ DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK,
+ ERROR_BUFFER_SIZE, ERROR_REDIRECITON_ENABLED
+}
+import org.apache.spark.util.dotnet.{Utils => DotnetUtils}
+import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
+import org.apache.spark.{SecurityManager, SparkConf, SparkUserAppException}
+
+import scala.collection.JavaConverters._
+import scala.io.StdIn
+import scala.util.Try
+
+/**
+ * DotnetRunner class used to launch Spark .NET applications using spark-submit.
+ * It executes .NET application as a subprocess and then has it connect back to
+ * the JVM to access system properties etc.
+ */
+object DotnetRunner extends Logging {
+ private val DEBUG_PORT = 5567
+ private val supportedSparkMajorMinorVersionPrefix = "3.3"
+ private val supportedSparkVersions = Set[String]("3.2.0", "3.2.1", "3.2.2", "3.2.3", "3.3.0")
+
+ val SPARK_VERSION = DotnetUtils.normalizeSparkVersion(spark.SPARK_VERSION)
+
+ def main(args: Array[String]): Unit = {
+ if (args.length == 0) {
+ throw new IllegalArgumentException("At least one argument is expected.")
+ }
+
+ DotnetUtils.validateSparkVersions(
+ sys.props
+ .getOrElse(
+ DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.key,
+ DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.defaultValue.get.toString)
+ .toBoolean,
+ spark.SPARK_VERSION,
+ SPARK_VERSION,
+ supportedSparkMajorMinorVersionPrefix,
+ supportedSparkVersions)
+
+ val settings = initializeSettings(args)
+
+ // Determines if this needs to be run in debug mode.
+ // In debug mode this runner will not launch a .NET process.
+ val runInDebugMode = settings._1
+ @volatile var dotnetBackendPortNumber = settings._2
+ var dotnetExecutable = ""
+ var otherArgs: Array[String] = null
+
+ if (!runInDebugMode) {
+ if (args(0).toLowerCase(Locale.ROOT).endsWith(".zip")) {
+ var zipFileName = args(0)
+ val zipFileUri = Try(new URI(zipFileName)).getOrElse(new File(zipFileName).toURI)
+ val workingDir = new File("").getAbsoluteFile
+ val driverDir = new File(workingDir, FilenameUtils.getBaseName(zipFileUri.getPath()))
+
+ // Standalone cluster mode where .NET application is remotely located.
+ if (zipFileUri.getScheme() != "file") {
+ zipFileName = downloadDriverFile(zipFileName, workingDir.getAbsolutePath).getName
+ }
+
+ logInfo(s"Unzipping .NET driver $zipFileName to $driverDir")
+ DotnetUtils.unzip(new File(zipFileName), driverDir)
+
+ // Reuse windows-specific formatting in PythonRunner.
+ dotnetExecutable = PythonRunner.formatPath(resolveDotnetExecutable(driverDir, args(1)))
+ otherArgs = args.slice(2, args.length)
+ } else {
+ // Reuse windows-specific formatting in PythonRunner.
+ dotnetExecutable = PythonRunner.formatPath(args(0))
+ otherArgs = args.slice(1, args.length)
+ }
+ } else {
+ otherArgs = args.slice(1, args.length)
+ }
+
+ val processParameters = new java.util.ArrayList[String]
+ processParameters.add(dotnetExecutable)
+ otherArgs.foreach(arg => processParameters.add(arg))
+
+ logInfo(s"Starting DotnetBackend with $dotnetExecutable.")
+
+ // Time to wait for DotnetBackend to initialize in seconds.
+ val backendTimeout = sys.env.getOrElse("DOTNETBACKEND_TIMEOUT", "120").toInt
+
+ // Launch a DotnetBackend server for the .NET process to connect to; this will let it see our
+ // Java system properties etc.
+ val dotnetBackend = new DotnetBackend()
+ val initialized = new Semaphore(0)
+ val dotnetBackendThread = new Thread("DotnetBackend") {
+ override def run(): Unit = {
+ // need to get back dotnetBackendPortNumber because if the value passed to init is 0
+ // the port number is dynamically assigned in the backend
+ dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber)
+ logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber")
+ initialized.release()
+ dotnetBackend.run()
+ }
+ }
+
+ dotnetBackendThread.start()
+
+ if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) {
+ if (!runInDebugMode) {
+ var returnCode = -1
+ var process: Process = null
+ val enableLogRedirection: Boolean = sys.props
+ .getOrElse(
+ ERROR_REDIRECITON_ENABLED.key,
+ ERROR_REDIRECITON_ENABLED.defaultValue.get.toString).toBoolean
+ val stderrBuffer: Option[CircularBuffer] = Option(enableLogRedirection).collect {
+ case true => new CircularBuffer(
+ sys.props.getOrElse(
+ ERROR_BUFFER_SIZE.key,
+ ERROR_BUFFER_SIZE.defaultValue.get.toString).toInt)
+ }
+
+ try {
+ val builder = new ProcessBuilder(processParameters)
+ val env = builder.environment()
+ env.put("DOTNETBACKEND_PORT", dotnetBackendPortNumber.toString)
+
+ for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) {
+ env.put(key, value)
+ logInfo(s"Adding key=$key and value=$value to environment")
+ }
+ builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
+ process = builder.start()
+
+ // Redirect stdin of JVM process to stdin of .NET process.
+ new RedirectThread(System.in, process.getOutputStream, "redirect JVM input").start()
+ // Redirect stdout and stderr of .NET process to System.out and to buffer
+ // if log redirection is enabled. If not, redirect only to System.out.
+ new RedirectThread(
+ process.getInputStream,
+ stderrBuffer match {
+ case Some(buffer) => new TeeOutputStream(System.out, buffer)
+ case _ => System.out
+ },
+ "redirect .NET stdout and stderr").start()
+
+ process.waitFor()
+ } catch {
+ case t: Throwable =>
+ logThrowable(t)
+ } finally {
+ returnCode = closeDotnetProcess(process)
+ closeBackend(dotnetBackend)
+ }
+ if (returnCode != 0) {
+ if (stderrBuffer.isDefined) {
+ throw new DotNetUserAppException(returnCode, Some(stderrBuffer.get.toString))
+ } else {
+ throw new SparkUserAppException(returnCode)
+ }
+ } else {
+ logInfo(s".NET application exited successfully")
+ }
+ // TODO: The following is causing the following error:
+ // INFO ApplicationMaster: Final app status: FAILED, exitCode: 16,
+ // (reason: Shutdown hook called before final status was reported.)
+ // DotnetUtils.exit(returnCode)
+ } else {
+ // scalastyle:off println
+ println("***********************************************************************")
+ println("* .NET Backend running debug mode. Press enter to exit *")
+ println("***********************************************************************")
+ // scalastyle:on println
+
+ StdIn.readLine()
+ closeBackend(dotnetBackend)
+ DotnetUtils.exit(0)
+ }
+ } else {
+ logError(s"DotnetBackend did not initialize in $backendTimeout seconds")
+ DotnetUtils.exit(-1)
+ }
+ }
+
+ // When the executable is downloaded as part of the zip file, check if the file exists
+ // after the zip file is unzipped under the given dir. Once it is found, change the
+ // permission to executable (only for Unix systems, since the zip file may have been
+ // created under Windows). Finally, the absolute path of the executable is returned.
+ private def resolveDotnetExecutable(dir: File, dotnetExecutable: String): String = {
+ val path = Paths.get(dir.getAbsolutePath, dotnetExecutable)
+ val resolvedExecutable = if (Files.isRegularFile(path)) {
+ path.toAbsolutePath.toString
+ } else {
+ Files
+ .walk(FileSystems.getDefault.getPath(dir.getAbsolutePath))
+ .iterator()
+ .asScala
+ .find(path => Files.isRegularFile(path) && path.getFileName.toString == dotnetExecutable) match {
+ case Some(path) => path.toAbsolutePath.toString
+ case None =>
+ throw new IllegalArgumentException(
+ s"Failed to find $dotnetExecutable under ${dir.getAbsolutePath}")
+ }
+ }
+
+ if (DotnetUtils.supportPosix) {
+ Files.setPosixFilePermissions(
+ Paths.get(resolvedExecutable),
+ PosixFilePermissions.fromString("rwxr-xr-x"))
+ }
+
+ resolvedExecutable
+ }
+
+ /**
+ * Download HDFS file into the supplied directory and return its local path.
+ * Will throw an exception if there are errors during downloading.
+ */
+ private def downloadDriverFile(hdfsFilePath: String, driverDir: String): File = {
+ val sparkConf = new SparkConf()
+ val filePath = new Path(hdfsFilePath)
+
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf)
+ val jarFileName = filePath.getName
+ val localFile = new File(driverDir, jarFileName)
+
+ if (!localFile.exists()) { // May already exist if running multiple workers on one node
+ logInfo(s"Copying user file $filePath to $driverDir")
+ Utils.fetchFile(
+ hdfsFilePath,
+ new File(driverDir),
+ sparkConf,
+ hadoopConf,
+ System.currentTimeMillis(),
+ useCache = false)
+ }
+
+ if (!localFile.exists()) {
+ throw new Exception(s"Did not see expected $jarFileName in $driverDir")
+ }
+
+ localFile
+ }
+
+ private def closeBackend(dotnetBackend: DotnetBackend): Unit = {
+ logInfo("Closing DotnetBackend")
+ dotnetBackend.close()
+ }
+
+ private def closeDotnetProcess(dotnetProcess: Process): Int = {
+ if (dotnetProcess == null) {
+ return -1
+ } else if (!dotnetProcess.isAlive) {
+ return dotnetProcess.exitValue()
+ }
+
+ // Try to (gracefully on Linux) kill the process and resort to force if interrupted
+ var returnCode = -1
+ logInfo("Closing .NET process")
+ try {
+ dotnetProcess.destroy()
+ returnCode = dotnetProcess.waitFor()
+ } catch {
+ case _: InterruptedException =>
+ logInfo(
+ "Thread interrupted while waiting for graceful close. Forcefully closing .NET process")
+ returnCode = dotnetProcess.destroyForcibly().waitFor()
+ case t: Throwable =>
+ logThrowable(t)
+ }
+
+ returnCode
+ }
+
+ private def initializeSettings(args: Array[String]): (Boolean, Int) = {
+ val runInDebugMode = (args.length == 1 || args.length == 2) && args(0).equalsIgnoreCase(
+ "debug")
+ var portNumber = 0
+ if (runInDebugMode) {
+ if (args.length == 1) {
+ portNumber = DEBUG_PORT
+ } else if (args.length == 2) {
+ portNumber = Integer.parseInt(args(1))
+ }
+ }
+
+ (runInDebugMode, portNumber)
+ }
+
+ private def logThrowable(throwable: Throwable): Unit =
+ logError(s"${throwable.getMessage} \n ${throwable.getStackTrace.mkString("\n")}")
+}
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala
new file mode 100644
index 000000000..18ba4c6e5
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.internal.config.dotnet
+
+import org.apache.spark.internal.config.ConfigBuilder
+
+private[spark] object Dotnet {
+ val DOTNET_NUM_BACKEND_THREADS = ConfigBuilder("spark.dotnet.numDotnetBackendThreads").intConf
+ .createWithDefault(10)
+
+ val DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK =
+ ConfigBuilder("spark.dotnet.ignoreSparkPatchVersionCheck").booleanConf
+ .createWithDefault(false)
+
+ val ERROR_REDIRECITON_ENABLED =
+ ConfigBuilder("spark.nonjvm.error.forwarding.enabled").booleanConf
+ .createWithDefault(false)
+
+ val ERROR_BUFFER_SIZE =
+ ConfigBuilder("spark.nonjvm.error.buffer.size")
+ .intConf
+ .checkValue(_ >= 0, "The error buffer size must not be negative")
+ .createWithDefault(10240)
+}
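
These entries are ordinary Spark configuration keys, so they can be set like any other conf; the values below are only examples:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.dotnet.numDotnetBackendThreads", "20")
      .set("spark.dotnet.ignoreSparkPatchVersionCheck", "true")
      .set("spark.nonjvm.error.forwarding.enabled", "true")
      .set("spark.nonjvm.error.buffer.size", "20480")
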
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
new file mode 100644
index 000000000..3e3c3e0e3
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
@@ -0,0 +1,26 @@
+
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.mllib.api.dotnet
+
+import org.apache.spark.ml._
+import scala.collection.JavaConverters._
+
+/** MLUtils object that hosts helper functions
+ * related to ML usage
+ */
+object MLUtils {
+
+ /** A helper function that lets a Pipeline accept its stages as a
+ * java.util.ArrayList from Scala code
+ * @param pipeline - The pipeline whose stages are to be set
+ * @param value - A java.util.ArrayList of PipelineStages to be set as the stages
+ * @return The pipeline with its stages set
+ */
+ def setPipelineStages(pipeline: Pipeline, value: java.util.ArrayList[_ <: PipelineStage]): Pipeline =
+ pipeline.setStages(value.asScala.toArray)
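+
+ // A minimal sketch of calling this helper from Scala (Tokenizer is only an example stage):
+ //   val stages = new java.util.ArrayList[PipelineStage]()
+ //   stages.add(new org.apache.spark.ml.feature.Tokenizer())
+ //   MLUtils.setPipelineStages(new Pipeline(), stages)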
+}
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala
new file mode 100644
index 000000000..5d06d4304
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.sql.api.dotnet
+
+import org.apache.spark.api.dotnet.CallbackClient
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.streaming.DataStreamWriter
+
+class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int) extends Logging {
+ def call(batchDF: DataFrame, batchId: Long): Unit =
+ callbackClient.send(
+ callbackId,
+ (dos, serDe) => {
+ serDe.writeJObj(dos, batchDF)
+ serDe.writeLong(dos, batchId)
+ })
+}
+
+object DotnetForeachBatchHelper {
+ def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = {
+ val dotnetForeachFunc = client match {
+ case Some(value) => new DotnetForeachBatchFunction(value, callbackId)
+ case None => throw new Exception("CallbackClient is null.")
+ }
+
+ dsw.foreachBatch(dotnetForeachFunc.call _)
+ }
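+
+ // The .NET side is expected to invoke callForeachBatch through the JVM bridge, with
+ // callbackId identifying the user's foreachBatch handler registered on the .NET CallbackServer.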
+}
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala
new file mode 100644
index 000000000..b5e97289a
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.sql.api.dotnet
+
+import java.util.{List => JList, Map => JMap}
+
+import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonFunction}
+import org.apache.spark.broadcast.Broadcast
+
+object SQLUtils {
+
+ /**
+ * Exposes createPythonFunction to the .NET client to enable registering UDFs.
+ */
+ def createPythonFunction(
+ command: Array[Byte],
+ envVars: JMap[String, String],
+ pythonIncludes: JList[String],
+ pythonExec: String,
+ pythonVersion: String,
+ broadcastVars: JList[Broadcast[PythonBroadcast]],
+ accumulator: PythonAccumulatorV2): PythonFunction = {
+
+ PythonFunction(
+ command,
+ envVars,
+ pythonIncludes,
+ pythonExec,
+ pythonVersion,
+ broadcastVars,
+ accumulator)
+ }
+}
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/test/TestUtils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/test/TestUtils.scala
new file mode 100644
index 000000000..1cd45aa95
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/test/TestUtils.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.sql.test
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.streaming.MemoryStream
+
+object TestUtils {
+
+ /**
+ * Helper method to create typed MemoryStreams intended for use in unit tests.
+ * @param sqlContext The SQLContext.
+ * @param streamType The type of memory stream to create. This string is the `Name`
+ * property of the dotnet type.
+ * @return A typed MemoryStream.
+ */
+ def createMemoryStream(implicit sqlContext: SQLContext, streamType: String): MemoryStream[_] = {
+ import sqlContext.implicits._
+
+ streamType match {
+ case "Int32" => MemoryStream[Int]
+ case "String" => MemoryStream[String]
+ case _ => throw new Exception(s"$streamType not supported")
+ }
+ }
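+
+ // A usage sketch (assuming a SparkSession named spark is available in the test):
+ //   val stream = TestUtils.createMemoryStream(spark.sqlContext, "Int32")
+ //   // stream is a MemoryStream[Int] that the test can feed via addData(...)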
+}
diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/util/dotnet/Utils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/util/dotnet/Utils.scala
new file mode 100644
index 000000000..f9400789a
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/util/dotnet/Utils.scala
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.util.dotnet
+
+import java.io._
+import java.nio.file.attribute.PosixFilePermission
+import java.nio.file.attribute.PosixFilePermission._
+import java.nio.file.{FileSystems, Files}
+import java.util.{Timer, TimerTask}
+
+import org.apache.commons.compress.archivers.zip.{ZipArchiveEntry, ZipArchiveOutputStream, ZipFile}
+import org.apache.commons.io.{FileUtils, IOUtils}
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK
+
+import scala.collection.JavaConverters._
+import scala.collection.Set
+
+/**
+ * Utility methods.
+ */
+object Utils extends Logging {
+ private val posixFilePermissions = Array(
+ OWNER_READ,
+ OWNER_WRITE,
+ OWNER_EXECUTE,
+ GROUP_READ,
+ GROUP_WRITE,
+ GROUP_EXECUTE,
+ OTHERS_READ,
+ OTHERS_WRITE,
+ OTHERS_EXECUTE)
+
+ val supportPosix: Boolean =
+ FileSystems.getDefault.supportedFileAttributeViews().contains("posix")
+
+ /**
+ * Compress all files under given directory into one zip file and drop it to the target directory
+ *
+ * @param sourceDir source directory to zip
+ * @param targetZipFile target zip file
+ */
+ def zip(sourceDir: File, targetZipFile: File): Unit = {
+ var fos: FileOutputStream = null
+ var zos: ZipArchiveOutputStream = null
+ try {
+ fos = new FileOutputStream(targetZipFile)
+ zos = new ZipArchiveOutputStream(fos)
+
+ val sourcePath = sourceDir.toPath
+ FileUtils.listFiles(sourceDir, null, true).asScala.foreach { file =>
+ var in: FileInputStream = null
+ try {
+ val path = file.toPath
+ val entry = new ZipArchiveEntry(sourcePath.relativize(path).toString)
+ if (supportPosix) {
+ entry.setUnixMode(
+ permissionsToMode(Files.getPosixFilePermissions(path).asScala)
+ | (if (entry.getName.endsWith(".exe")) 0x1ED else 0x1A4))
+ } else if (entry.getName.endsWith(".exe")) {
+ entry.setUnixMode(0x1ED) // 755
+ } else {
+ entry.setUnixMode(0x1A4) // 644
+ }
+ zos.putArchiveEntry(entry)
+
+ in = new FileInputStream(file)
+ IOUtils.copy(in, zos)
+ zos.closeArchiveEntry()
+ } finally {
+ IOUtils.closeQuietly(in)
+ }
+ }
+ } finally {
+ IOUtils.closeQuietly(zos)
+ IOUtils.closeQuietly(fos)
+ }
+ }
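+
+ // Example (illustrative): zip(new File("/tmp/publish"), new File("/tmp/publish.zip"))
+ // archives every file under /tmp/publish, preserving POSIX permissions where supported.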
+
+ /**
+ * Unzip a file to the given directory
+ *
+ * @param file file to be unzipped
+ * @param targetDir target directory
+ */
+ def unzip(file: File, targetDir: File): Unit = {
+ var zipFile: ZipFile = null
+ try {
+ targetDir.mkdirs()
+ zipFile = new ZipFile(file)
+ zipFile.getEntries.asScala.foreach { entry =>
+ val targetFile = new File(targetDir, entry.getName)
+
+ if (targetFile.exists()) {
+ logWarning(
+ s"Target file/directory $targetFile already exists. Skip it for now. " +
+ s"Make sure this is expected.")
+ } else {
+ if (entry.isDirectory) {
+ targetFile.mkdirs()
+ } else {
+ targetFile.getParentFile.mkdirs()
+ val input = zipFile.getInputStream(entry)
+ val output = new FileOutputStream(targetFile)
+ IOUtils.copy(input, output)
+ IOUtils.closeQuietly(input)
+ IOUtils.closeQuietly(output)
+ if (supportPosix) {
+ val permissions = modeToPermissions(entry.getUnixMode)
+ // If the entry carries no unix mode (e.g. the archive was created on a
+ // non-POSIX system), the permission set is empty; skip applying it so the
+ // file's existing permissions are not wiped out.
+ if (permissions.nonEmpty) {
+ Files.setPosixFilePermissions(targetFile.toPath, permissions.asJava)
+ }
+ }
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("exception caught during decompression:" + e)
+ } finally {
+ ZipFile.closeQuietly(zipFile)
+ }
+ }
+
+ /**
+ * Exits the JVM, trying to do it nicely, otherwise doing it nastily.
+ *
+ * @param status the exit status, zero for OK, non-zero for error
+ * @param maxDelayMillis the maximum delay in milliseconds
+ */
+ def exit(status: Int, maxDelayMillis: Long): Unit = {
+ try {
+ logInfo(s"Utils.exit() with status: $status, maxDelayMillis: $maxDelayMillis")
+
+ // setup a timer, so if nice exit fails, the nasty exit happens
+ val timer = new Timer()
+ timer.schedule(new TimerTask() {
+
+ override def run(): Unit = {
+ Runtime.getRuntime.halt(status)
+ }
+ }, maxDelayMillis)
+ // try to exit nicely
+ System.exit(status)
+ } catch {
+ // exit nastily if we have a problem
+ case _: Throwable => Runtime.getRuntime.halt(status)
+ } finally {
+ // should never get here
+ Runtime.getRuntime.halt(status)
+ }
+ }
+
+ /**
+ * Exits the JVM, trying to do it nicely and waiting at most 1 second before forcing the exit.
+ *
+ * @param status the exit status, zero for OK, non-zero for error
+ */
+ def exit(status: Int): Unit = {
+ exit(status, 1000)
+ }
+
+ /**
+ * Normalize the Spark version by taking the first three numbers.
+ * For example:
+ * x.y.z => x.y.z
+ * x.y.z.xxx.yyy => x.y.z
+ * x.y => x.y
+ *
+ * @param version the Spark version to normalize
+ * @return Normalized Spark version.
+ */
+ def normalizeSparkVersion(version: String): String = {
+ version
+ .split('.')
+ .take(3)
+ .zipWithIndex
+ .map({
+ case (element, index) => {
+ index match {
+ case 2 => element.split("\\D+").lift(0).getOrElse("")
+ case _ => element
+ }
+ }
+ })
+ .mkString(".")
+ }
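+
+ // For example (illustrative): normalizeSparkVersion("3.3.0-SNAPSHOT") returns "3.3.0".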
+
+ /**
+ * Validates the normalized spark version by verifying:
+ * - Spark version starts with sparkMajorMinorVersionPrefix.
+ * - If ignoreSparkPatchVersion is
+ * - true: valid
+ * - false: check if the spark version is in supportedSparkVersions.
+ * @param ignoreSparkPatchVersion Ignore spark patch version.
+ * @param sparkVersion The spark version.
+ * @param normalizedSparkVersion The normalized spark version.
+ * @param supportedSparkMajorMinorVersionPrefix The spark major and minor version to validate against.
+ * @param supportedSparkVersions The set of supported spark versions.
+ */
+ def validateSparkVersions(
+ ignoreSparkPatchVersion: Boolean,
+ sparkVersion: String,
+ normalizedSparkVersion: String,
+ supportedSparkMajorMinorVersionPrefix: String,
+ supportedSparkVersions: Set[String]): Unit = {
+ if (!normalizedSparkVersion.startsWith(s"$supportedSparkMajorMinorVersionPrefix.")) {
+ throw new IllegalArgumentException(
+ s"Unsupported spark version used: '$sparkVersion'. " +
+ s"Normalized spark version used: '$normalizedSparkVersion'. " +
+ s"Supported spark major.minor version: '$supportedSparkMajorMinorVersionPrefix'.")
+ } else if (ignoreSparkPatchVersion) {
+ logWarning(
+ s"Ignoring spark patch version. Spark version used: '$sparkVersion'. " +
+ s"Normalized spark version used: '$normalizedSparkVersion'. " +
+ s"Spark major.minor prefix used: '$supportedSparkMajorMinorVersionPrefix'.")
+ } else if (!supportedSparkVersions(normalizedSparkVersion)) {
+ val supportedVersions = supportedSparkVersions.toSeq.sorted.mkString(", ")
+ throw new IllegalArgumentException(
+ s"Unsupported spark version used: '$sparkVersion'. " +
+ s"Normalized spark version used: '$normalizedSparkVersion'. " +
+ s"Supported versions: '$supportedVersions'.")
+ }
+ }
+
+ private[spark] def listZipFileEntries(file: File): Array[String] = {
+ var zipFile: ZipFile = null
+ try {
+ zipFile = new ZipFile(file)
+ zipFile.getEntries.asScala.map(_.getName).toArray
+ } finally {
+ ZipFile.closeQuietly(zipFile)
+ }
+ }
+
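+ // The nine POSIX permission bits are packed in rwxrwxrwx order (owner, group, others);
+ // e.g. rwxr-xr-x corresponds to 0x1ED (octal 755) and rw-r--r-- to 0x1A4 (octal 644).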
+ private[this] def permissionsToMode(permissions: Set[PosixFilePermission]): Int = {
+ posixFilePermissions.foldLeft(0) { (mode, perm) =>
+ (mode << 1) | (if (permissions.contains(perm)) 1 else 0)
+ }
+ }
+
+ private[this] def modeToPermissions(mode: Int): Set[PosixFilePermission] = {
+ posixFilePermissions.zipWithIndex
+ .filter { case (_, i) => (mode & (0x100 >>> i)) != 0 }
+ .map(_._1)
+ .toSet
+ }
+}
diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala
new file mode 100644
index 000000000..7088537e1
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+
+package org.apache.spark.api.dotnet
+
+import Extensions._
+import org.junit.Assert._
+import org.junit.{After, Before, Test}
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+@Test
+class DotnetBackendHandlerTest {
+ private var backend: DotnetBackend = _
+ private var tracker: JVMObjectTracker = _
+ private var handler: DotnetBackendHandler = _
+
+ @Before
+ def before(): Unit = {
+ backend = new DotnetBackend
+ tracker = new JVMObjectTracker
+ handler = new DotnetBackendHandler(backend, tracker)
+ }
+
+ @After
+ def after(): Unit = {
+ backend.close()
+ }
+
+ @Test
+ def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = {
+ val message = givenMessage(m => {
+ val serDe = new SerDe(null)
+ m.writeBoolean(true) // static method
+ serDe.writeInt(m, 1) // processId
+ serDe.writeInt(m, 1) // threadId
+ serDe.writeString(m, "DotnetHandler") // class name
+ serDe.writeString(m, "connectCallback") // command (method) name
+ m.writeInt(2) // number of arguments
+ m.writeByte('c') // 1st argument type (string)
+ serDe.writeString(m, "127.0.0.1") // 1st argument value (host)
+ m.writeByte('i') // 2nd argument type (integer)
+ m.writeInt(0) // 2nd argument value (port)
+ })
+
+ val payload = handler.handleBackendRequest(message)
+ val reply = new DataInputStream(new ByteArrayInputStream(payload))
+
+ assertEquals(
+ "status code must be successful.", 0, reply.readInt())
+ assertEquals('j', reply.readByte())
+ assertEquals(1, reply.readInt())
+ val trackingId = new String(reply.readNBytes(1), "UTF-8")
+ assertEquals("1", trackingId)
+ val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull
+ assertEquals(classOf[CallbackClient], client.getClass)
+ }
+
+ private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = {
+ val buffer = new ByteArrayOutputStream()
+ func(new DataOutputStream(buffer))
+ buffer.toByteArray
+ }
+}
diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
new file mode 100644
index 000000000..445486bbd
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+
+package org.apache.spark.api.dotnet
+
+import org.junit.Assert._
+import org.junit.{After, Before, Test}
+
+import java.net.InetAddress
+
+@Test
+class DotnetBackendTest {
+ private var backend: DotnetBackend = _
+
+ @Before
+ def before(): Unit = {
+ backend = new DotnetBackend
+ }
+
+ @After
+ def after(): Unit = {
+ backend.close()
+ }
+
+ @Test
+ def shouldNotResetCallbackClient(): Unit = {
+ // Specifying port = 0 to select port dynamically.
+ backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
+
+ assertTrue(backend.callbackClient.isDefined)
+ assertThrows(classOf[Exception], () => {
+ backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
+ })
+ }
+}
diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
new file mode 100644
index 000000000..c6904403b
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+
+package org.apache.spark.api.dotnet
+
+import java.io.DataInputStream
+
+private[dotnet] object Extensions {
+ implicit class DataInputStreamExt(stream: DataInputStream) {
+ def readNBytes(n: Int): Array[Byte] = {
+ val buf = new Array[Byte](n)
+ stream.readFully(buf)
+ buf
+ }
+ }
+}
diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
new file mode 100644
index 000000000..43ae79005
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import org.junit.Test
+
+@Test
+class JVMObjectTrackerTest {
+
+ @Test
+ def shouldReleaseAllReferences(): Unit = {
+ val tracker = new JVMObjectTracker
+ val firstId = tracker.put(new Object)
+ val secondId = tracker.put(new Object)
+ val thirdId = tracker.put(new Object)
+
+ tracker.clear()
+
+ assert(tracker.get(firstId).isEmpty)
+ assert(tracker.get(secondId).isEmpty)
+ assert(tracker.get(thirdId).isEmpty)
+ }
+
+ @Test
+ def shouldResetCounter(): Unit = {
+ val tracker = new JVMObjectTracker
+ val firstId = tracker.put(new Object)
+ val secondId = tracker.put(new Object)
+
+ tracker.clear()
+
+ val thirdId = tracker.put(new Object)
+
+ assert(firstId.equals("1"))
+ assert(secondId.equals("2"))
+ assert(thirdId.equals("1"))
+ }
+}
diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
new file mode 100644
index 000000000..41401d680
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
@@ -0,0 +1,373 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import org.apache.spark.api.dotnet.Extensions._
+import org.apache.spark.sql.Row
+import org.junit.Assert._
+import org.junit.{Before, Test}
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.sql.Date
+import scala.collection.JavaConverters._
+
+@Test
+class SerDeTest {
+ private var serDe: SerDe = _
+ private var tracker: JVMObjectTracker = _
+
+ @Before
+ def before(): Unit = {
+ tracker = new JVMObjectTracker
+ serDe = new SerDe(tracker)
+ }
+
+ @Test
+ def shouldReadNull(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('n')
+ })
+
+ assertEquals(null, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldThrowForUnsupportedTypes(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('_')
+ })
+
+ assertThrows(classOf[IllegalArgumentException], () => {
+ serDe.readObject(input)
+ })
+ }
+
+ @Test
+ def shouldReadInteger(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('i')
+ in.writeInt(42)
+ })
+
+ assertEquals(42, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadLong(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('g')
+ in.writeLong(42)
+ })
+
+ assertEquals(42L, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadDouble(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('d')
+ in.writeDouble(42.42)
+ })
+
+ assertEquals(42.42, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadBoolean(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('b')
+ in.writeBoolean(true)
+ })
+
+ assertEquals(true, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadString(): Unit = {
+ val payload = "Spark Dotnet"
+ val input = givenInput(in => {
+ in.writeByte('c')
+ in.writeInt(payload.getBytes("UTF-8").length)
+ in.write(payload.getBytes("UTF-8"))
+ })
+
+ assertEquals(payload, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadMap(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('e') // map type descriptor
+ in.writeInt(3) // size
+ in.writeByte('i') // key type
+ in.writeInt(3) // number of keys
+ in.writeInt(11) // first key
+ in.writeInt(22) // second key
+ in.writeInt(33) // third key
+ in.writeInt(3) // number of values
+ in.writeByte('b') // first value type
+ in.writeBoolean(true) // first value
+ in.writeByte('d') // second value type
+ in.writeDouble(42.42) // second value
+ in.writeByte('n') // third type & value
+ })
+
+ assertEquals(
+ mapAsJavaMap(Map(
+ 11 -> true,
+ 22 -> 42.42,
+ 33 -> null)),
+ serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadEmptyMap(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('e') // map type descriptor
+ in.writeInt(0) // size
+ })
+
+ assertEquals(mapAsJavaMap(Map()), serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadBytesArray(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('r') // byte array type descriptor
+ in.writeInt(3) // length
+ in.write(Array[Byte](1, 2, 3)) // payload
+ })
+
+ assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]])
+ }
+
+ @Test
+ def shouldReadEmptyBytesArray(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('r') // byte array type descriptor
+ in.writeInt(0) // length
+ })
+
+ assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]])
+ }
+
+ @Test
+ def shouldReadEmptyList(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('l') // type descriptor
+ in.writeByte('i') // element type
+ in.writeInt(0) // length
+ })
+
+ assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]])
+ }
+
+ @Test
+ def shouldReadList(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('l') // type descriptor
+ in.writeByte('b') // element type
+ in.writeInt(3) // length
+ in.writeBoolean(true)
+ in.writeBoolean(false)
+ in.writeBoolean(true)
+ })
+
+ assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]])
+ }
+
+ @Test
+ def shouldThrowWhenReadingListWithUnsupportedType(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('l') // type descriptor
+ in.writeByte('_') // unsupported element type
+ })
+
+ assertThrows(classOf[IllegalArgumentException], () => {
+ serDe.readObject(input)
+ })
+ }
+
+ @Test
+ def shouldReadDate(): Unit = {
+ val input = givenInput(in => {
+ val date = "2020-12-31"
+ in.writeByte('D') // type descriptor
+ in.writeInt(date.getBytes("UTF-8").length) // date string size
+ in.write(date.getBytes("UTF-8"))
+ })
+
+ assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadObject(): Unit = {
+ val trackingObject = new Object
+ tracker.put(trackingObject)
+ val input = givenInput(in => {
+ val objectIndex = "1"
+ in.writeByte('j') // type descriptor
+ in.writeInt(objectIndex.getBytes("UTF-8").length) // size
+ in.write(objectIndex.getBytes("UTF-8"))
+ })
+
+ assertSame(trackingObject, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldThrowWhenReadingNonTrackingObject(): Unit = {
+ val input = givenInput(in => {
+ val objectIndex = "42"
+ in.writeByte('j') // type descriptor
+ in.writeInt(objectIndex.getBytes("UTF-8").length) // size
+ in.write(objectIndex.getBytes("UTF-8"))
+ })
+
+ assertThrows(classOf[NoSuchElementException], () => {
+ serDe.readObject(input)
+ })
+ }
+
+ @Test
+ def shouldReadSparkRows(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('R') // type descriptor
+ in.writeInt(2) // number of rows
+ in.writeInt(1) // number of elements in 1st row
+ in.writeByte('i') // type of 1st element in 1st row
+ in.writeInt(11)
+ in.writeInt(3) // number of elements in 2nd row
+ in.writeByte('b') // type of 1st element in 2nd row
+ in.writeBoolean(true)
+ in.writeByte('d') // type of 2nd element in 2nd row
+ in.writeDouble(42.24)
+ in.writeByte('g') // type of 3rd element in 2nd row
+ in.writeLong(99)
+ })
+
+ assertEquals(
+ seqAsJavaList(Seq(
+ Row.fromSeq(Seq(11)),
+ Row.fromSeq(Seq(true, 42.24, 99)))),
+ serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadArrayOfObjects(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('O') // type descriptor
+ in.writeInt(2) // number of elements
+ in.writeByte('i') // type of 1st element
+ in.writeInt(42)
+ in.writeByte('b') // type of 2nd element
+ in.writeBoolean(true)
+ })
+
+ assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]])
+ }
+
+ @Test
+ def shouldWriteNull(): Unit = {
+ val in = whenOutput(out => {
+ serDe.writeObject(out, null)
+ serDe.writeObject(out, Unit)
+ })
+
+ assertEquals(in.readByte(), 'n')
+ assertEquals(in.readByte(), 'n')
+ assertEndOfStream(in)
+ }
+
+ @Test
+ def shouldWriteString(): Unit = {
+ val sparkDotnet = "Spark Dotnet"
+ val in = whenOutput(out => {
+ serDe.writeObject(out, sparkDotnet)
+ })
+
+ assertEquals(in.readByte(), 'c') // object type
+ assertEquals(in.readInt(), sparkDotnet.length) // length
+ assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8"))
+ assertEndOfStream(in)
+ }
+
+ @Test
+ def shouldWritePrimitiveTypes(): Unit = {
+ val in = whenOutput(out => {
+ serDe.writeObject(out, 42.24f.asInstanceOf[Object])
+ serDe.writeObject(out, 42L.asInstanceOf[Object])
+ serDe.writeObject(out, 42.asInstanceOf[Object])
+ serDe.writeObject(out, true.asInstanceOf[Object])
+ })
+
+ assertEquals(in.readByte(), 'd')
+ assertEquals(in.readDouble(), 42.24F, 0.000001)
+ assertEquals(in.readByte(), 'g')
+ assertEquals(in.readLong(), 42L)
+ assertEquals(in.readByte(), 'i')
+ assertEquals(in.readInt(), 42)
+ assertEquals(in.readByte(), 'b')
+ assertEquals(in.readBoolean(), true)
+ assertEndOfStream(in)
+ }
+
+ @Test
+ def shouldWriteDate(): Unit = {
+ val date = "2020-12-31"
+ val in = whenOutput(out => {
+ serDe.writeObject(out, Date.valueOf(date))
+ })
+
+ assertEquals(in.readByte(), 'D') // type
+ assertEquals(in.readInt(), 10) // size
+ assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content
+ }
+
+ @Test
+ def shouldWriteCustomObjects(): Unit = {
+ val customObject = new Object
+ val in = whenOutput(out => {
+ serDe.writeObject(out, customObject)
+ })
+
+ assertEquals(in.readByte(), 'j')
+ assertEquals(in.readInt(), 1)
+ assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8"))
+ assertSame(tracker.get("1").get, customObject)
+ }
+
+ @Test
+ def shouldWriteArrayOfCustomObjects(): Unit = {
+ val payload = Array(new Object, new Object)
+ val in = whenOutput(out => {
+ serDe.writeObject(out, payload)
+ })
+
+ assertEquals(in.readByte(), 'l') // array type
+ assertEquals(in.readByte(), 'j') // type of element in array
+ assertEquals(in.readInt(), 2) // array length
+ assertEquals(in.readInt(), 1) // size of 1st element's identifier
+ assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element
+ assertEquals(in.readInt(), 1) // size of 2nd element's identifier
+ assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element
+ assertSame(tracker.get("1").get, payload(0))
+ assertSame(tracker.get("2").get, payload(1))
+ }
+
+ private def givenInput(func: DataOutputStream => Unit): DataInputStream = {
+ val buffer = new ByteArrayOutputStream()
+ val out = new DataOutputStream(buffer)
+ func(out)
+ new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
+ }
+
+ private def whenOutput = givenInput _
+
+ private def assertEndOfStream(in: DataInputStream): Unit = {
+ assertEquals(-1, in.read())
+ }
+}
diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala
new file mode 100644
index 000000000..b8e0122fd
--- /dev/null
+++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.util.dotnet
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK
+import org.junit.Assert.{assertEquals, assertThrows}
+import org.junit.Test
+
+@Test
+class UtilsTest {
+
+ @Test
+ def shouldIgnorePatchVersion(): Unit = {
+ val sparkVersion = "3.3.0"
+ val sparkMajorMinorVersionPrefix = "3.3"
+ val supportedSparkVersions = Set[String]("3.3.0")
+
+ Utils.validateSparkVersions(
+ true,
+ sparkVersion,
+ Utils.normalizeSparkVersion(sparkVersion),
+ sparkMajorMinorVersionPrefix,
+ supportedSparkVersions)
+ }
+
+ @Test
+ def shouldThrowForUnsupportedVersion(): Unit = {
+ val sparkVersion = "3.3.0"
+ val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion)
+ val sparkMajorMinorVersionPrefix = "3.3"
+ val supportedSparkVersions = Set[String]("3.3.0")
+
+ val exception = assertThrows(
+ classOf[IllegalArgumentException],
+ () => {
+ Utils.validateSparkVersions(
+ false,
+ sparkVersion,
+ normalizedSparkVersion,
+ sparkMajorMinorVersionPrefix,
+ supportedSparkVersions)
+ })
+
+ assertEquals(
+ s"Unsupported spark version used: '$sparkVersion'. " +
+ s"Normalized spark version used: '$normalizedSparkVersion'. " +
+ s"Supported versions: '${supportedSparkVersions.toSeq.sorted.mkString(", ")}'.",
+ exception.getMessage)
+ }
+
+ @Test
+ def shouldThrowForUnsupportedMajorMinorVersion(): Unit = {
+ val sparkVersion = "2.4.4"
+ val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion)
+ val sparkMajorMinorVersionPrefix = "3.3"
+ val supportedSparkVersions = Set[String]("3.3.0")
+
+ val exception = assertThrows(
+ classOf[IllegalArgumentException],
+ () => {
+ Utils.validateSparkVersions(
+ false,
+ sparkVersion,
+ normalizedSparkVersion,
+ sparkMajorMinorVersionPrefix,
+ supportedSparkVersions)
+ })
+
+ assertEquals(
+ s"Unsupported spark version used: '$sparkVersion'. " +
+ s"Normalized spark version used: '$normalizedSparkVersion'. " +
+ s"Supported spark major.minor version: '$sparkMajorMinorVersionPrefix'.",
+ exception.getMessage)
+ }
+}
diff --git a/src/scala/microsoft-spark-3-4/pom.xml b/src/scala/microsoft-spark-3-4/pom.xml
new file mode 100644
index 000000000..319c0c84d
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/pom.xml
@@ -0,0 +1,83 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>com.microsoft.scala</groupId>
+    <artifactId>microsoft-spark</artifactId>
+    <version>${microsoft-spark.version}</version>
+  </parent>
+  <artifactId>microsoft-spark-3-4_2.12</artifactId>
+  <inceptionYear>2019</inceptionYear>
+  <properties>
+    <encoding>UTF-8</encoding>
+    <scala.version>2.12.10</scala.version>
+    <scala.binary.version>2.12</scala.binary.version>
+    <spark.version>3.4.0</spark.version>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+      <version>${scala.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <version>4.13.1</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.specs</groupId>
+      <artifactId>specs</artifactId>
+      <version>1.2.5</version>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <sourceDirectory>src/main/scala</sourceDirectory>
+    <testSourceDirectory>src/test/scala</testSourceDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.scala-tools</groupId>
+        <artifactId>maven-scala-plugin</artifactId>
+        <version>2.15.2</version>
+        <executions>
+          <execution>
+            <goals>
+              <goal>compile</goal>
+              <goal>testCompile</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <scalaVersion>${scala.version}</scalaVersion>
+          <args>
+            <arg>-target:jvm-1.8</arg>
+            <arg>-deprecation</arg>
+            <arg>-feature</arg>
+          </args>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
new file mode 100644
index 000000000..aea355dfa
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.io.DataOutputStream
+
+import org.apache.spark.internal.Logging
+
+import scala.collection.mutable.Queue
+
+/**
+ * CallbackClient is used to communicate with the Dotnet CallbackServer.
+ * The client manages and maintains a pool of open CallbackConnections.
+ * Any callback request is delegated to a new CallbackConnection or
+ * unused CallbackConnection.
+ * @param address The address of the Dotnet CallbackServer
+ * @param port The port of the Dotnet CallbackServer
+ */
+class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging {
+ private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]()
+
+ private[this] var isShutdown: Boolean = false
+
+ final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit =
+ getOrCreateConnection() match {
+ case Some(connection) =>
+ try {
+ connection.send(callbackId, writeBody)
+ addConnection(connection)
+ } catch {
+ case e: Exception =>
+ logError(s"Error calling callback [callback id = $callbackId].", e)
+ connection.close()
+ throw e
+ }
+ case None => throw new Exception("Unable to get or create connection.")
+ }
+
+ private def getOrCreateConnection(): Option[CallbackConnection] = synchronized {
+ if (isShutdown) {
+ logInfo("Cannot get or create connection while client is shutdown.")
+ return None
+ }
+
+ if (connectionPool.nonEmpty) {
+ return Some(connectionPool.dequeue())
+ }
+
+ Some(new CallbackConnection(serDe, address, port))
+ }
+
+ private def addConnection(connection: CallbackConnection): Unit = synchronized {
+ assert(connection != null)
+ connectionPool.enqueue(connection)
+ }
+
+ def shutdown(): Unit = synchronized {
+ if (isShutdown) {
+ logInfo("Shutdown called, but already shutdown.")
+ return
+ }
+
+ logInfo("Shutting down.")
+ connectionPool.foreach(_.close)
+ connectionPool.clear
+ isShutdown = true
+ }
+}
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala
new file mode 100644
index 000000000..604cf029b
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.io.{ByteArrayOutputStream, Closeable, DataInputStream, DataOutputStream}
+import java.net.Socket
+
+import org.apache.spark.internal.Logging
+
+/**
+ * CallbackConnection is used to process the callback communication
+ * between the JVM and Dotnet. It uses a TCP socket to communicate with
+ * the Dotnet CallbackServer and the socket is expected to be reused.
+ * @param address The address of the Dotnet CallbackServer
+ * @param port The port of the Dotnet CallbackServer
+ */
+class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging {
+ private[this] val socket: Socket = new Socket(address, port)
+ private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream)
+ private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream)
+
+ def send(
+ callbackId: Int,
+ writeBody: (DataOutputStream, SerDe) => Unit): Unit = {
+ logInfo(s"Calling callback [callback id = $callbackId] ...")
+
+ try {
+ serDe.writeInt(outputStream, CallbackFlags.CALLBACK)
+ serDe.writeInt(outputStream, callbackId)
+
+ val byteArrayOutputStream = new ByteArrayOutputStream()
+ writeBody(new DataOutputStream(byteArrayOutputStream), serDe)
+ serDe.writeInt(outputStream, byteArrayOutputStream.size)
+ byteArrayOutputStream.writeTo(outputStream)
+ } catch {
+ case e: Exception => {
+ throw new Exception("Error writing to stream.", e)
+ }
+ }
+
+ logInfo(s"Signaling END_OF_STREAM.")
+ try {
+ serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM)
+ outputStream.flush()
+
+ val endOfStreamResponse = readFlag(inputStream)
+ endOfStreamResponse match {
+ case CallbackFlags.END_OF_STREAM =>
+ logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.")
+ case _ => {
+ throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " +
+ s"Received: $endOfStreamResponse")
+ }
+ }
+ } catch {
+ case e: Exception => {
+ throw new Exception("Error while verifying end of stream.", e)
+ }
+ }
+ }
+
+ def close(): Unit = {
+ try {
+ serDe.writeInt(outputStream, CallbackFlags.CLOSE)
+ outputStream.flush()
+ } catch {
+ case e: Exception => logInfo("Unable to send close to .NET callback server.", e)
+ }
+
+ close(socket)
+ close(outputStream)
+ close(inputStream)
+ }
+
+ private def close(s: Socket): Unit = {
+ try {
+ assert(s != null)
+ s.close()
+ } catch {
+ case e: Exception => logInfo("Unable to close socket.", e)
+ }
+ }
+
+ private def close(c: Closeable): Unit = {
+ try {
+ assert(c != null)
+ c.close()
+ } catch {
+ case e: Exception => logInfo("Unable to close closeable.", e)
+ }
+ }
+
+ private def readFlag(inputStream: DataInputStream): Int = {
+ val callbackFlag = serDe.readInt(inputStream)
+ if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) {
+ val exceptionMessage = serDe.readString(inputStream)
+ throw new DotnetException(exceptionMessage)
+ }
+ callbackFlag
+ }
+
+ private object CallbackFlags {
+ val CLOSE: Int = -1
+ val CALLBACK: Int = -2
+ val DOTNET_EXCEPTION_THROWN: Int = -3
+ val END_OF_STREAM: Int = -4
+ }
+}
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
new file mode 100644
index 000000000..c6f528aee
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.net.InetSocketAddress
+import java.util.concurrent.TimeUnit
+import io.netty.bootstrap.ServerBootstrap
+import io.netty.channel.nio.NioEventLoopGroup
+import io.netty.channel.socket.SocketChannel
+import io.netty.channel.socket.nio.NioServerSocketChannel
+import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder
+import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_NUM_BACKEND_THREADS
+import org.apache.spark.{SparkConf, SparkEnv}
+
+/**
+ * Netty-based server that invokes JVM methods in response to messages received from .NET.
+ * The implementation mirrors the RBackend.
+ */
+class DotnetBackend extends Logging {
+ self => // for accessing the this reference in inner class(ChannelInitializer)
+ private[this] var channelFuture: ChannelFuture = _
+ private[this] var bootstrap: ServerBootstrap = _
+ private[this] var bossGroup: EventLoopGroup = _
+ private[this] val objectTracker = new JVMObjectTracker
+
+ @volatile
+ private[dotnet] var callbackClient: Option[CallbackClient] = None
+
+ def init(portNumber: Int): Int = {
+ val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+ val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS)
+ logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.")
+ bossGroup = new NioEventLoopGroup(numBackendThreads)
+ val workerGroup = bossGroup
+
+ bootstrap = new ServerBootstrap()
+ .group(bossGroup, workerGroup)
+ .channel(classOf[NioServerSocketChannel])
+
+ bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
+ def initChannel(ch: SocketChannel): Unit = {
+ ch.pipeline()
+ .addLast("encoder", new ByteArrayEncoder())
+ .addLast(
+ "frameDecoder",
+ // maxFrameLength = 2G
+ // lengthFieldOffset = 0
+ // lengthFieldLength = 4
+ // lengthAdjustment = 0
+ // initialBytesToStrip = 4, i.e. strip out the length field itself
+ new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
+ .addLast("decoder", new ByteArrayDecoder())
+ .addLast("handler", new DotnetBackendHandler(self, objectTracker))
+ }
+ })
+
+ channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber))
+ channelFuture.syncUninterruptibly()
+ channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
+ }
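+
+ // Note: passing 0 binds to an ephemeral port; the returned value is the port actually bound.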
+
+ private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized {
+ callbackClient = callbackClient match {
+ case Some(_) => throw new Exception("Callback client already set.")
+ case None =>
+ logInfo(s"Connecting to a callback server at $address:$port")
+ Some(new CallbackClient(new SerDe(objectTracker), address, port))
+ }
+ }
+
+ private[dotnet] def shutdownCallbackClient(): Unit = synchronized {
+ callbackClient match {
+ case Some(client) => client.shutdown()
+ case None => logInfo("Callback server has already been shutdown.")
+ }
+ callbackClient = None
+ }
+
+ def run(): Unit = {
+ channelFuture.channel.closeFuture().syncUninterruptibly()
+ }
+
+ def close(): Unit = {
+ if (channelFuture != null) {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS)
+ channelFuture = null
+ }
+ if (bootstrap != null && bootstrap.config().group() != null) {
+ bootstrap.config().group().shutdownGracefully()
+ }
+ if (bootstrap != null && bootstrap.config().childGroup() != null) {
+ bootstrap.config().childGroup().shutdownGracefully()
+ }
+ bootstrap = null
+
+ objectTracker.clear()
+
+ // Send close to .NET callback server.
+ shutdownCallbackClient()
+
+ // Shutdown the thread pool whose executors could still be running.
+ ThreadPool.shutdown()
+ }
+}
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
new file mode 100644
index 000000000..fc3a5c911
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import scala.collection.mutable.HashMap
+import scala.language.existentials
+
+/**
+ * Handler for DotnetBackend.
+ * This implementation is similar to RBackendHandler.
+ */
+class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker)
+ extends SimpleChannelInboundHandler[Array[Byte]]
+ with Logging {
+
+ private[this] val serDe = new SerDe(objectsTracker)
+
+ override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
+ val reply = handleBackendRequest(msg)
+ ctx.write(reply)
+ }
+
+ override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
+ ctx.flush()
+ }
+
+ def handleBackendRequest(msg: Array[Byte]): Array[Byte] = {
+ val bis = new ByteArrayInputStream(msg)
+ val dis = new DataInputStream(bis)
+
+ val bos = new ByteArrayOutputStream()
+ val dos = new DataOutputStream(bos)
+
+ // First bit is isStatic
+ val isStatic = serDe.readBoolean(dis)
+ val processId = serDe.readInt(dis)
+ val threadId = serDe.readInt(dis)
+ val objId = serDe.readString(dis)
+ val methodName = serDe.readString(dis)
+ val numArgs = serDe.readInt(dis)
+
+ if (objId == "DotnetHandler") {
+ methodName match {
+ case "stopBackend" =>
+ serDe.writeInt(dos, 0)
+ serDe.writeType(dos, "void")
+ server.close()
+ case "rm" =>
+ try {
+ val t = serDe.readObjectType(dis)
+ assert(t == 'c')
+ val objToRemove = serDe.readString(dis)
+ objectsTracker.remove(objToRemove)
+ serDe.writeInt(dos, 0)
+ serDe.writeObject(dos, null)
+ } catch {
+ case e: Exception =>
+ logError(s"Removing $objId failed", e)
+ serDe.writeInt(dos, -1)
+ }
+ case "rmThread" =>
+ try {
+ assert(serDe.readObjectType(dis) == 'i')
+ val processId = serDe.readInt(dis)
+ assert(serDe.readObjectType(dis) == 'i')
+ val threadToDelete = serDe.readInt(dis)
+ val result = ThreadPool.tryDeleteThread(processId, threadToDelete)
+ serDe.writeInt(dos, 0)
+ serDe.writeObject(dos, result.asInstanceOf[AnyRef])
+ } catch {
+ case e: Exception =>
+ logError(s"Removing thread $threadId failed", e)
+ serDe.writeInt(dos, -1)
+ }
+ case "connectCallback" =>
+ assert(serDe.readObjectType(dis) == 'c')
+ val address = serDe.readString(dis)
+ assert(serDe.readObjectType(dis) == 'i')
+ val port = serDe.readInt(dis)
+ server.setCallbackClient(address, port)
+ serDe.writeInt(dos, 0)
+
+ // Sends reference of CallbackClient to dotnet side,
+ // so that dotnet process can send the client back to Java side
+ // when calling any API containing callback functions.
+ serDe.writeObject(dos, server.callbackClient)
+ case "closeCallback" =>
+ logInfo("Requesting to close callback client")
+ server.shutdownCallbackClient()
+ serDe.writeInt(dos, 0)
+ serDe.writeType(dos, "void")
+ case _ => dos.writeInt(-1)
+ }
+ } else {
+ ThreadPool
+ .run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
+ }
+
+ bos.toByteArray
+ }
+
+ override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
+ // Skip logging the exception message if the connection was disconnected from
+ // the .NET side so that .NET side doesn't have to explicitly close the connection via
+ // "stopBackend." Note that an exception is still thrown if the exit status is non-zero,
+ // so skipping this kind of exception message does not affect the debugging.
+ if (!cause.getMessage.contains(
+ "An existing connection was forcibly closed by the remote host")) {
+ logError("Exception caught: ", cause)
+ }
+
+ // Close the connection when an exception is raised.
+ ctx.close()
+ }
+
+ def handleMethodCall(
+ isStatic: Boolean,
+ objId: String,
+ methodName: String,
+ numArgs: Int,
+ dis: DataInputStream,
+ dos: DataOutputStream): Unit = {
+ var obj: Object = null
+ var args: Array[java.lang.Object] = null
+ var methods: Array[java.lang.reflect.Method] = null
+
+ try {
+ val cls = if (isStatic) {
+ Utils.classForName(objId)
+ } else {
+ objectsTracker.get(objId) match {
+ case None => throw new IllegalArgumentException("Object not found " + objId)
+ case Some(o) =>
+ obj = o
+ o.getClass
+ }
+ }
+
+ args = readArgs(numArgs, dis)
+ methods = cls.getMethods
+
+ val selectedMethods = methods.filter(m => m.getName == methodName)
+ if (selectedMethods.length > 0) {
+ val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args)
+
+ if (index.isEmpty) {
+ logWarning(
+ s"cannot find matching method ${cls}.$methodName. "
+ + s"Candidates are:")
+ selectedMethods.foreach { method =>
+ logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})")
+ }
+ throw new Exception(s"No matched method found for $cls.$methodName")
+ }
+
+ val ret = selectedMethods(index.get).invoke(obj, args: _*)
+
+ // Write status bit
+ serDe.writeInt(dos, 0)
+ serDe.writeObject(dos, ret.asInstanceOf[AnyRef])
+ } else if (methodName == "") {
+ // methodName should be "" for constructor
+ val ctor = cls.getConstructors.filter { x =>
+ matchMethod(numArgs, args, x.getParameterTypes)
+ }.head
+
+ val obj = ctor.newInstance(args: _*)
+
+ serDe.writeInt(dos, 0)
+ serDe.writeObject(dos, obj.asInstanceOf[AnyRef])
+ } else {
+ throw new IllegalArgumentException(
+ "invalid method " + methodName + " for object " + objId)
+ }
+ } catch {
+ case e: Throwable =>
+ val jvmObj = objectsTracker.get(objId)
+ val jvmObjName = jvmObj match {
+ case Some(jObj) => jObj.getClass.getName
+ case None => "NullObject"
+ }
+ val argsStr = args
+ .map(arg => {
+ if (arg != null) {
+ s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
+ } else {
+ "[Value: NULL]"
+ }
+ })
+ .mkString(", ")
+
+ logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
+
+ if (methods != null) {
+ logDebug(s"All methods for $jvmObjName:")
+ methods.foreach(m => logDebug(m.toString))
+ }
+
+ serDe.writeInt(dos, -1)
+ serDe.writeString(dos, Utils.exceptionString(e.getCause))
+ }
+ }
+
+ // Read a number of arguments from the data input stream
+ def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
+ (0 until numArgs).map { arg =>
+ serDe.readObject(dis)
+ }.toArray
+ }
+
+ // Checks if the arguments passed in args matches the parameter types.
+ // NOTE: Currently we do exact match. We may add type conversions later.
+ def matchMethod(
+ numArgs: Int,
+ args: Array[java.lang.Object],
+ parameterTypes: Array[Class[_]]): Boolean = {
+ if (parameterTypes.length != numArgs) {
+ return false
+ }
+
+ for (i <- 0 until numArgs) {
+ val parameterType = parameterTypes(i)
+ var parameterWrapperType = parameterType
+
+ // Convert native parameters to Object types as args is Array[Object] here
+ if (parameterType.isPrimitive) {
+ parameterWrapperType = parameterType match {
+ case java.lang.Integer.TYPE => classOf[java.lang.Integer]
+ case java.lang.Long.TYPE => classOf[java.lang.Long]
+ case java.lang.Double.TYPE => classOf[java.lang.Double]
+ case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
+ case _ => parameterType
+ }
+ }
+
+ if (!parameterWrapperType.isInstance(args(i))) {
+ // non primitive types
+ if (!parameterType.isPrimitive && args(i) != null) {
+ return false
+ }
+
+ // primitive types
+ if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) {
+ return false
+ }
+ }
+ }
+
+ true
+ }
+
+ // Find a matching method signature in an array of signatures of constructors
+ // or methods of the same name according to the passed arguments. Arguments
+ // may be converted in order to match a signature.
+ //
+ // Note that in Java reflection, constructors and normal methods are of different
+ // classes, and share no parent class that provides methods for reflection uses.
+ // There is no unified way to handle them in this function. So an array of signatures
+ // is passed in instead of an array of candidate constructors or methods.
+ //
+ // Returns an Option[Int] which is the index of the matched signature in the array.
+ def findMatchedSignature(
+ parameterTypesOfMethods: Array[Array[Class[_]]],
+ args: Array[Object]): Option[Int] = {
+ val numArgs = args.length
+
+ for (index <- parameterTypesOfMethods.indices) {
+ val parameterTypes = parameterTypesOfMethods(index)
+
+ if (parameterTypes.length == numArgs) {
+ var argMatched = true
+ var i = 0
+ while (i < numArgs && argMatched) {
+ val parameterType = parameterTypes(i)
+
+ if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) {
+ // The case that the parameter type is a Scala Seq and the argument
+ // is a Java array is considered matching. The array will be converted
+ // to a Seq later if this method is matched.
+ } else {
+ var parameterWrapperType = parameterType
+
+ // Convert native parameters to Object types as args is Array[Object] here
+ if (parameterType.isPrimitive) {
+ parameterWrapperType = parameterType match {
+ case java.lang.Integer.TYPE => classOf[java.lang.Integer]
+ case java.lang.Long.TYPE => classOf[java.lang.Long]
+ case java.lang.Double.TYPE => classOf[java.lang.Double]
+ case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
+ case _ => parameterType
+ }
+ }
+ if ((parameterType.isPrimitive || args(i) != null) &&
+ !parameterWrapperType.isInstance(args(i))) {
+ argMatched = false
+ }
+ }
+
+ i = i + 1
+ }
+
+ if (argMatched) {
+ // For now, we return the first matching method.
+ // TODO: find best method in matching methods.
+
+ // Convert args if needed
+ val parameterTypes = parameterTypesOfMethods(index)
+
+ for (i <- 0 until numArgs) {
+ if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) {
+ // Convert a Java array to scala Seq
+ args(i) = args(i).asInstanceOf[Array[_]].toSeq
+ }
+ }
+
+ return Some(index)
+ }
+ }
+ }
+ None
+ }
+
+ def logError(id: String, e: Exception): Unit = {}
+}
+
+
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala
new file mode 100644
index 000000000..c70d16b03
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala
@@ -0,0 +1,13 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+class DotnetException(message: String, cause: Throwable)
+ extends Exception(message, cause) {
+
+ def this(message: String) = this(message, null)
+}
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala
new file mode 100644
index 000000000..f5277c215
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python._
+import org.apache.spark.rdd.RDD
+
+object DotnetRDD {
+ def createPythonRDD(
+ parent: RDD[_],
+ func: PythonFunction,
+ preservePartitioning: Boolean): PythonRDD = {
+ new PythonRDD(parent, func, preservePartitioning)
+ }
+
+ def createJavaRDDFromArray(
+ sc: SparkContext,
+ arr: Array[Array[Byte]],
+ numSlices: Int): JavaRDD[Array[Byte]] = {
+ JavaRDD.fromRDD(sc.parallelize(arr, numSlices))
+ }
+
+ def toJavaRDD(rdd: RDD[_]): JavaRDD[_] = JavaRDD.fromRDD(rdd)
+}
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
new file mode 100644
index 000000000..9f556338b
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import scala.collection.JavaConverters._
+
+/** DotnetUtils hosts helper functions that assist with data type
+ * conversions between .NET and Scala.
+ */
+object DotnetUtils {
+
+ /** A helper function to convert a Scala Map to a java.util.Map.
+ * @param value - Scala Map
+ * @return java.util.Map
+ */
+ def convertToJavaMap(value: Map[_, _]): java.util.Map[_, _] = value.asJava
+
+ /** Convert a Java boxed value to the corresponding Scala type.
+ * @param value - java.lang.Object
+ * @return Any
+ */
+ def mapScalaToJava(value: java.lang.Object): Any = {
+ value match {
+ case i: java.lang.Integer => i.toInt
+ case d: java.lang.Double => d.toDouble
+ case f: java.lang.Float => f.toFloat
+ case b: java.lang.Boolean => b.booleanValue()
+ case l: java.lang.Long => l.toLong
+ case s: java.lang.Short => s.toShort
+ case by: java.lang.Byte => by.toByte
+ case c: java.lang.Character => c.toChar
+ case _ => value
+ }
+ }
+}
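
Usage sketch (illustrative only, not part of the patch): boxed values arriving over the bridge are unwrapped to Scala primitives, and Scala maps are exposed to the .NET side as java.util.Map:

    DotnetUtils.mapScalaToJava(java.lang.Integer.valueOf(42))   // 42: Int
    DotnetUtils.mapScalaToJava(java.lang.Boolean.TRUE)          // true: Boolean
    DotnetUtils.convertToJavaMap(Map("a" -> 1, "b" -> 2))       // java.util.Map[_, _]
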
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala
new file mode 100644
index 000000000..81cfaf88b
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+
+package org.apache.spark.api.dotnet
+
+import scala.collection.mutable.HashMap
+
+/**
+ * Tracks JVM objects returned to .NET, which is useful for invoking calls from .NET on JVM objects.
+ */
+private[dotnet] class JVMObjectTracker {
+
+ // Multiple threads may access objMap and increment objCounter. Because the get method returns an
+ // Option, it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap.
+ private[this] val objMap = new HashMap[String, Object]
+ private[this] var objCounter: Int = 1
+
+ def getObject(id: String): Object = {
+ synchronized {
+ objMap(id)
+ }
+ }
+
+ def get(id: String): Option[Object] = {
+ synchronized {
+ objMap.get(id)
+ }
+ }
+
+ def put(obj: Object): String = {
+ synchronized {
+ val objId = objCounter.toString
+ objCounter = objCounter + 1
+ objMap.put(objId, obj)
+ objId
+ }
+ }
+
+ def remove(id: String): Option[Object] = {
+ synchronized {
+ objMap.remove(id)
+ }
+ }
+
+ def clear(): Unit = {
+ synchronized {
+ objMap.clear()
+ objCounter = 1
+ }
+ }
+}
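
A short lifecycle sketch of the tracker (illustrative only, not part of the patch): ids are sequential strings, and clear() resets both the map and the counter:

    val tracker = new JVMObjectTracker
    val id = tracker.put(new Object)   // "1"
    tracker.get(id)                    // Some(obj)
    tracker.remove(id)                 // Some(obj); a later get(id) returns None
    tracker.clear()                    // next put(...) starts again at "1"
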
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala
new file mode 100644
index 000000000..06a476f67
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.sql.api.dotnet
+
+import org.apache.spark.SparkConf
+
+/*
+ * Utils for JvmBridge.
+ */
+object JvmBridgeUtils {
+ def getKeyValuePairAsString(kvp: (String, String)): String = {
+ return kvp._1 + "=" + kvp._2
+ }
+
+ def getKeyValuePairArrayAsString(kvpArray: Array[(String, String)]): String = {
+ val sb = new StringBuilder
+
+ for (kvp <- kvpArray) {
+ sb.append(getKeyValuePairAsString(kvp))
+ sb.append(";")
+ }
+
+ sb.toString
+ }
+
+ def getSparkConfAsString(sparkConf: SparkConf): String = {
+ getKeyValuePairArrayAsString(sparkConf.getAll)
+ }
+}
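
The string produced here is a simple "key=value;" concatenation, e.g. (illustrative only, not part of the patch):

    JvmBridgeUtils.getKeyValuePairArrayAsString(
      Array(("spark.app.name", "demo"), ("spark.master", "local[*]")))
    // "spark.app.name=demo;spark.master=local[*];"
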
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
new file mode 100644
index 000000000..a3df3788a
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
@@ -0,0 +1,387 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.io.{DataInputStream, DataOutputStream}
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, Time, Timestamp}
+
+import org.apache.spark.sql.Row
+
+import scala.collection.JavaConverters._
+
+/**
+ * Class responsible for serialization and deserialization between CLR & JVM.
+ * This implementation of methods is mostly identical to the SerDe implementation in R.
+ */
+class SerDe(val tracker: JVMObjectTracker) {
+
+ def readObjectType(dis: DataInputStream): Char = {
+ dis.readByte().toChar
+ }
+
+ def readObject(dis: DataInputStream): Object = {
+ val dataType = readObjectType(dis)
+ readTypedObject(dis, dataType)
+ }
+
+ private def readTypedObject(dis: DataInputStream, dataType: Char): Object = {
+ dataType match {
+ case 'n' => null
+ case 'i' => new java.lang.Integer(readInt(dis))
+ case 'g' => new java.lang.Long(readLong(dis))
+ case 'd' => new java.lang.Double(readDouble(dis))
+ case 'b' => new java.lang.Boolean(readBoolean(dis))
+ case 'c' => readString(dis)
+ case 'e' => readMap(dis)
+ case 'r' => readBytes(dis)
+ case 'l' => readList(dis)
+ case 'D' => readDate(dis)
+ case 't' => readTime(dis)
+ case 'j' => tracker.getObject(readString(dis))
+ case 'R' => readRowArr(dis)
+ case 'O' => readObjectArr(dis)
+ case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+ }
+ }
+
+ private def readBytes(in: DataInputStream): Array[Byte] = {
+ val len = readInt(in)
+ val out = new Array[Byte](len)
+ in.readFully(out)
+ out
+ }
+
+ def readInt(in: DataInputStream): Int = {
+ in.readInt()
+ }
+
+ private def readLong(in: DataInputStream): Long = {
+ in.readLong()
+ }
+
+ private def readDouble(in: DataInputStream): Double = {
+ in.readDouble()
+ }
+
+ private def readStringBytes(in: DataInputStream, len: Int): String = {
+ val bytes = new Array[Byte](len)
+ in.readFully(bytes)
+ val str = new String(bytes, "UTF-8")
+ str
+ }
+
+ def readString(in: DataInputStream): String = {
+ val len = in.readInt()
+ readStringBytes(in, len)
+ }
+
+ def readBoolean(in: DataInputStream): Boolean = {
+ in.readBoolean()
+ }
+
+ private def readDate(in: DataInputStream): Date = {
+ Date.valueOf(readString(in))
+ }
+
+ private def readTime(in: DataInputStream): Timestamp = {
+ val seconds = in.readDouble()
+ val sec = Math.floor(seconds).toLong
+ val t = new Timestamp(sec * 1000L)
+ t.setNanos(((seconds - sec) * 1e9).toInt)
+ t
+ }
+
+ private def readRow(in: DataInputStream): Row = {
+ val len = readInt(in)
+ Row.fromSeq((0 until len).map(_ => readObject(in)))
+ }
+
+ private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readBytes(in)).toArray
+ }
+
+ private def readIntArr(in: DataInputStream): Array[Int] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readInt(in)).toArray
+ }
+
+ private def readLongArr(in: DataInputStream): Array[Long] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readLong(in)).toArray
+ }
+
+ private def readDoubleArr(in: DataInputStream): Array[Double] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readDouble(in)).toArray
+ }
+
+ private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readDoubleArr(in)).toArray
+ }
+
+ private def readBooleanArr(in: DataInputStream): Array[Boolean] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readBoolean(in)).toArray
+ }
+
+ private def readStringArr(in: DataInputStream): Array[String] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readString(in)).toArray
+ }
+
+ private def readRowArr(in: DataInputStream): java.util.List[Row] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readRow(in)).toList.asJava
+ }
+
+ private def readObjectArr(in: DataInputStream): Seq[Any] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readObject(in))
+ }
+
+ private def readList(dis: DataInputStream): Array[_] = {
+ val arrType = readObjectType(dis)
+ arrType match {
+ case 'i' => readIntArr(dis)
+ case 'g' => readLongArr(dis)
+ case 'c' => readStringArr(dis)
+ case 'd' => readDoubleArr(dis)
+ case 'A' => readDoubleArrArr(dis)
+ case 'b' => readBooleanArr(dis)
+ case 'j' => readStringArr(dis).map(x => tracker.getObject(x))
+ case 'r' => readBytesArr(dis)
+ case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
+ }
+ }
+
+ private def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
+ val len = readInt(in)
+ if (len > 0) {
+ val keysType = readObjectType(in)
+ val keysLen = readInt(in)
+ val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
+
+ val valuesLen = readInt(in)
+ val values = (0 until valuesLen).map(_ => {
+ val valueType = readObjectType(in)
+ readTypedObject(in, valueType)
+ })
+ keys.zip(values).toMap.asJava
+ } else {
+ new java.util.HashMap[Object, Object]()
+ }
+ }
+
+ // Using the same mapping as SparkR implementation for now
+ // Methods to write out data from Java to .NET.
+ //
+ // Type mapping from Java to .NET:
+ //
+ // void -> NULL
+ // Int -> integer
+ // String -> character
+ // Boolean -> logical
+ // Float -> double
+ // Double -> double
+ // Long -> long
+ // Array[Byte] -> raw
+ // Date -> Date
+ // Time -> POSIXct
+ //
+ // Array[T] -> list()
+ // Object -> jobj
+
+ def writeType(dos: DataOutputStream, typeStr: String): Unit = {
+ typeStr match {
+ case "void" => dos.writeByte('n')
+ case "character" => dos.writeByte('c')
+ case "double" => dos.writeByte('d')
+ case "doublearray" => dos.writeByte('A')
+ case "long" => dos.writeByte('g')
+ case "integer" => dos.writeByte('i')
+ case "logical" => dos.writeByte('b')
+ case "date" => dos.writeByte('D')
+ case "time" => dos.writeByte('t')
+ case "raw" => dos.writeByte('r')
+ case "list" => dos.writeByte('l')
+ case "jobj" => dos.writeByte('j')
+ case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
+ }
+ }
+
+ def writeObject(dos: DataOutputStream, value: Object): Unit = {
+ if (value == null || value == Unit) {
+ writeType(dos, "void")
+ } else {
+ value.getClass.getName match {
+ case "java.lang.String" =>
+ writeType(dos, "character")
+ writeString(dos, value.asInstanceOf[String])
+ case "float" | "java.lang.Float" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Float].toDouble)
+ case "double" | "java.lang.Double" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Double])
+ case "long" | "java.lang.Long" =>
+ writeType(dos, "long")
+ writeLong(dos, value.asInstanceOf[Long])
+ case "int" | "java.lang.Integer" =>
+ writeType(dos, "integer")
+ writeInt(dos, value.asInstanceOf[Int])
+ case "boolean" | "java.lang.Boolean" =>
+ writeType(dos, "logical")
+ writeBoolean(dos, value.asInstanceOf[Boolean])
+ case "java.sql.Date" =>
+ writeType(dos, "date")
+ writeDate(dos, value.asInstanceOf[Date])
+ case "java.sql.Time" =>
+ writeType(dos, "time")
+ writeTime(dos, value.asInstanceOf[Time])
+ case "java.sql.Timestamp" =>
+ writeType(dos, "time")
+ writeTime(dos, value.asInstanceOf[Timestamp])
+ case "[B" =>
+ writeType(dos, "raw")
+ writeBytes(dos, value.asInstanceOf[Array[Byte]])
+ // TODO: Types not handled right now include
+ // byte, char, short, float
+
+ // Handle arrays
+ case "[Ljava.lang.String;" =>
+ writeType(dos, "list")
+ writeStringArr(dos, value.asInstanceOf[Array[String]])
+ case "[I" =>
+ writeType(dos, "list")
+ writeIntArr(dos, value.asInstanceOf[Array[Int]])
+ case "[J" =>
+ writeType(dos, "list")
+ writeLongArr(dos, value.asInstanceOf[Array[Long]])
+ case "[D" =>
+ writeType(dos, "list")
+ writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
+ case "[[D" =>
+ writeType(dos, "list")
+ writeDoubleArrArr(dos, value.asInstanceOf[Array[Array[Double]]])
+ case "[Z" =>
+ writeType(dos, "list")
+ writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
+ case "[[B" =>
+ writeType(dos, "list")
+ writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
+ case otherName =>
+ // Handle array of objects
+ if (otherName.startsWith("[L")) {
+ val objArr = value.asInstanceOf[Array[Object]]
+ writeType(dos, "list")
+ writeType(dos, "jobj")
+ dos.writeInt(objArr.length)
+ objArr.foreach(o => writeJObj(dos, o))
+ } else {
+ writeType(dos, "jobj")
+ writeJObj(dos, value)
+ }
+ }
+ }
+ }
+
+ def writeInt(out: DataOutputStream, value: Int): Unit = {
+ out.writeInt(value)
+ }
+
+ def writeLong(out: DataOutputStream, value: Long): Unit = {
+ out.writeLong(value)
+ }
+
+ private def writeDouble(out: DataOutputStream, value: Double): Unit = {
+ out.writeDouble(value)
+ }
+
+ private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = {
+ out.writeBoolean(value)
+ }
+
+ private def writeDate(out: DataOutputStream, value: Date): Unit = {
+ writeString(out, value.toString)
+ }
+
+ private def writeTime(out: DataOutputStream, value: Time): Unit = {
+ out.writeDouble(value.getTime.toDouble / 1000.0)
+ }
+
+ private def writeTime(out: DataOutputStream, value: Timestamp): Unit = {
+ out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9)
+ }
+
+ def writeString(out: DataOutputStream, value: String): Unit = {
+ val utf8 = value.getBytes(StandardCharsets.UTF_8)
+ val len = utf8.length
+ out.writeInt(len)
+ out.write(utf8, 0, len)
+ }
+
+ private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = {
+ out.writeInt(value.length)
+ out.write(value)
+ }
+
+ def writeJObj(out: DataOutputStream, value: Object): Unit = {
+ val objId = tracker.put(value)
+ writeString(out, objId)
+ }
+
+ private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = {
+ writeType(out, "integer")
+ out.writeInt(value.length)
+ value.foreach(v => out.writeInt(v))
+ }
+
+ private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = {
+ writeType(out, "long")
+ out.writeInt(value.length)
+ value.foreach(v => out.writeLong(v))
+ }
+
+ private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = {
+ writeType(out, "double")
+ out.writeInt(value.length)
+ value.foreach(v => out.writeDouble(v))
+ }
+
+ private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = {
+ writeType(out, "doublearray")
+ out.writeInt(value.length)
+ value.foreach(v => writeDoubleArr(out, v))
+ }
+
+ private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = {
+ writeType(out, "logical")
+ out.writeInt(value.length)
+ value.foreach(v => writeBoolean(out, v))
+ }
+
+ private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = {
+ writeType(out, "character")
+ out.writeInt(value.length)
+ value.foreach(v => writeString(out, v))
+ }
+
+ private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
+ writeType(out, "raw")
+ out.writeInt(value.length)
+ value.foreach(v => writeBytes(out, v))
+ }
+}
+
+private object SerializationFormats {
+ val BYTE = "byte"
+ val STRING = "string"
+ val ROW = "row"
+}
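
A minimal round-trip sketch through the SerDe (illustrative only, not part of the patch): an Integer is written with the 'i' type tag and read back as the boxed value:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    val serDe = new SerDe(new JVMObjectTracker)
    val buffer = new ByteArrayOutputStream()
    serDe.writeObject(new DataOutputStream(buffer), Integer.valueOf(42))    // writes 'i', then 42
    val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
    serDe.readObject(in)                                                    // 42
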
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
new file mode 100644
index 000000000..50551a7d9
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import java.util.concurrent.{ExecutorService, Executors}
+
+import scala.collection.mutable
+
+/**
+ * Pool of thread executors. There should be a 1-1 correspondence between C# threads
+ * and Java threads.
+ */
+object ThreadPool {
+
+ /**
+ * Map from (processId, threadId) to corresponding executor.
+ */
+ private val executors: mutable.HashMap[(Int, Int), ExecutorService] =
+ new mutable.HashMap[(Int, Int), ExecutorService]()
+
+ /**
+ * Run some code on a particular thread.
+ * @param processId Integer id of the process.
+ * @param threadId Integer id of the thread.
+ * @param task Function to run on the thread.
+ */
+ def run(processId: Int, threadId: Int, task: () => Unit): Unit = {
+ val executor = getOrCreateExecutor(processId, threadId)
+ val future = executor.submit(new Runnable {
+ override def run(): Unit = task()
+ })
+
+ future.get()
+ }
+
+ /**
+ * Try to delete a particular thread.
+ * @param processId Integer id of the process.
+ * @param threadId Integer id of the thread.
+ * @return True if successful, false if thread does not exist.
+ */
+ def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized {
+ executors.remove((processId, threadId)) match {
+ case Some(executorService) =>
+ executorService.shutdown()
+ true
+ case None => false
+ }
+ }
+
+ /**
+ * Shutdown any running ExecutorServices.
+ */
+ def shutdown(): Unit = synchronized {
+ executors.foreach(_._2.shutdown())
+ executors.clear()
+ }
+
+ /**
+ * Get the executor if it exists, otherwise create a new one.
+ * @param processId Integer id of the process.
+ * @param threadId Integer id of the thread.
+ * @return The new or existing executor with the given id.
+ */
+ private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized {
+ executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor)
+ }
+}
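
Usage sketch (illustrative only, not part of the patch): each (processId, threadId) pair maps to a dedicated single-thread executor, so repeated calls with the same ids run on the same JVM thread:

    ThreadPool.run(1, 7, () => println(s"running on ${Thread.currentThread().getName}"))
    ThreadPool.run(1, 7, () => println("same dedicated thread as above"))
    ThreadPool.tryDeleteThread(1, 7)   // true: the executor existed and was shut down
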
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala
new file mode 100644
index 000000000..4551a70bd
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.deploy.dotnet
+
+import org.apache.spark.SparkException
+
+/**
+ * This exception type describes an exception thrown by a .NET user application.
+ *
+ * @param exitCode Exit code returned by the .NET application.
+ * @param dotNetStackTrace Stacktrace extracted from .NET application logs.
+ */
+private[spark] class DotNetUserAppException(exitCode: Int, dotNetStackTrace: Option[String])
+ extends SparkException(
+ dotNetStackTrace match {
+ case None => s"User application exited with $exitCode"
+ case Some(e) => s"User application exited with $exitCode and .NET exception: $e"
+ })
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala
new file mode 100644
index 000000000..2d1c2cf3b
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.deploy.dotnet
+
+import java.io.File
+import java.net.URI
+import java.nio.file.attribute.PosixFilePermissions
+import java.nio.file.{FileSystems, Files, Paths}
+import java.util.Locale
+import java.util.concurrent.{Semaphore, TimeUnit}
+
+import org.apache.commons.io.FilenameUtils
+import org.apache.commons.io.output.TeeOutputStream
+import org.apache.hadoop.fs.Path
+import org.apache.spark
+import org.apache.spark.api.dotnet.DotnetBackend
+import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.dotnet.Dotnet.{
+ DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK,
+ ERROR_BUFFER_SIZE, ERROR_REDIRECITON_ENABLED
+}
+import org.apache.spark.util.dotnet.{Utils => DotnetUtils}
+import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
+import org.apache.spark.{SecurityManager, SparkConf, SparkUserAppException}
+
+import scala.collection.JavaConverters._
+import scala.io.StdIn
+import scala.util.Try
+
+/**
+ * DotnetRunner is used to launch Spark .NET applications via spark-submit.
+ * It executes the .NET application as a subprocess and then has it connect back to
+ * the JVM to access system properties, etc.
+ */
+object DotnetRunner extends Logging {
+ private val DEBUG_PORT = 5567
+ private val supportedSparkMajorMinorVersionPrefix = "3.4"
+ private val supportedSparkVersions = Set[String]("3.2.0", "3.2.1", "3.2.2", "3.2.3", "3.3.0", "3.3.1", "3.3.2", "3.4.0")
+
+ val SPARK_VERSION = DotnetUtils.normalizeSparkVersion(spark.SPARK_VERSION)
+
+ def main(args: Array[String]): Unit = {
+ if (args.length == 0) {
+ throw new IllegalArgumentException("At least one argument is expected.")
+ }
+
+ DotnetUtils.validateSparkVersions(
+ sys.props
+ .getOrElse(
+ DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.key,
+ DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.defaultValue.get.toString)
+ .toBoolean,
+ spark.SPARK_VERSION,
+ SPARK_VERSION,
+ supportedSparkMajorMinorVersionPrefix,
+ supportedSparkVersions)
+
+ val settings = initializeSettings(args)
+
+ // Determines if this needs to be run in debug mode.
+ // In debug mode this runner will not launch a .NET process.
+ val runInDebugMode = settings._1
+ @volatile var dotnetBackendPortNumber = settings._2
+ var dotnetExecutable = ""
+ var otherArgs: Array[String] = null
+
+ if (!runInDebugMode) {
+ if (args(0).toLowerCase(Locale.ROOT).endsWith(".zip")) {
+ var zipFileName = args(0)
+ val zipFileUri = Try(new URI(zipFileName)).getOrElse(new File(zipFileName).toURI)
+ val workingDir = new File("").getAbsoluteFile
+ val driverDir = new File(workingDir, FilenameUtils.getBaseName(zipFileUri.getPath()))
+
+ // Standalone cluster mode where .NET application is remotely located.
+ if (zipFileUri.getScheme() != "file") {
+ zipFileName = downloadDriverFile(zipFileName, workingDir.getAbsolutePath).getName
+ }
+
+ logInfo(s"Unzipping .NET driver $zipFileName to $driverDir")
+ DotnetUtils.unzip(new File(zipFileName), driverDir)
+
+ // Reuse windows-specific formatting in PythonRunner.
+ dotnetExecutable = PythonRunner.formatPath(resolveDotnetExecutable(driverDir, args(1)))
+ otherArgs = args.slice(2, args.length)
+ } else {
+ // Reuse windows-specific formatting in PythonRunner.
+ dotnetExecutable = PythonRunner.formatPath(args(0))
+ otherArgs = args.slice(1, args.length)
+ }
+ } else {
+ otherArgs = args.slice(1, args.length)
+ }
+
+ val processParameters = new java.util.ArrayList[String]
+ processParameters.add(dotnetExecutable)
+ otherArgs.foreach(arg => processParameters.add(arg))
+
+ logInfo(s"Starting DotnetBackend with $dotnetExecutable.")
+
+ // Time to wait for DotnetBackend to initialize in seconds.
+ val backendTimeout = sys.env.getOrElse("DOTNETBACKEND_TIMEOUT", "120").toInt
+
+ // Launch a DotnetBackend server for the .NET process to connect to; this will let it see our
+ // Java system properties etc.
+ val dotnetBackend = new DotnetBackend()
+ val initialized = new Semaphore(0)
+ val dotnetBackendThread = new Thread("DotnetBackend") {
+ override def run() {
+ // Capture the port number returned by init: if the value passed in is 0,
+ // the backend assigns the port number dynamically.
+ dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber)
+ logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber")
+ initialized.release()
+ dotnetBackend.run()
+ }
+ }
+
+ dotnetBackendThread.start()
+
+ if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) {
+ if (!runInDebugMode) {
+ var returnCode = -1
+ var process: Process = null
+ val enableLogRedirection: Boolean = sys.props
+ .getOrElse(
+ ERROR_REDIRECITON_ENABLED.key,
+ ERROR_REDIRECITON_ENABLED.defaultValue.get.toString).toBoolean
+ val stderrBuffer: Option[CircularBuffer] = Option(enableLogRedirection).collect {
+ case true => new CircularBuffer(
+ sys.props.getOrElse(
+ ERROR_BUFFER_SIZE.key,
+ ERROR_BUFFER_SIZE.defaultValue.get.toString).toInt)
+ }
+
+ try {
+ val builder = new ProcessBuilder(processParameters)
+ val env = builder.environment()
+ env.put("DOTNETBACKEND_PORT", dotnetBackendPortNumber.toString)
+
+ for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) {
+ env.put(key, value)
+ logInfo(s"Adding key=$key and value=$value to environment")
+ }
+ builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
+ process = builder.start()
+
+ // Redirect stdin of JVM process to stdin of .NET process.
+ new RedirectThread(System.in, process.getOutputStream, "redirect JVM input").start()
+ // Redirect stdout and stderr of the .NET process to System.out and to the buffer
+ // if log redirection is enabled. If not, redirect only to System.out.
+ new RedirectThread(
+ process.getInputStream,
+ stderrBuffer match {
+ case Some(buffer) => new TeeOutputStream(System.out, buffer)
+ case _ => System.out
+ },
+ "redirect .NET stdout and stderr").start()
+
+ process.waitFor()
+ } catch {
+ case t: Throwable =>
+ logThrowable(t)
+ } finally {
+ returnCode = closeDotnetProcess(process)
+ closeBackend(dotnetBackend)
+ }
+ if (returnCode != 0) {
+ if (stderrBuffer.isDefined) {
+ throw new DotNetUserAppException(returnCode, Some(stderrBuffer.get.toString))
+ } else {
+ throw new SparkUserAppException(returnCode)
+ }
+ } else {
+ logInfo(s".NET application exited successfully")
+ }
+ // TODO: The following is causing the following error:
+ // INFO ApplicationMaster: Final app status: FAILED, exitCode: 16,
+ // (reason: Shutdown hook called before final status was reported.)
+ // DotnetUtils.exit(returnCode)
+ } else {
+ // scalastyle:off println
+ println("***********************************************************************")
+ println("* .NET Backend running debug mode. Press enter to exit *")
+ println("***********************************************************************")
+ // scalastyle:on println
+
+ StdIn.readLine()
+ closeBackend(dotnetBackend)
+ DotnetUtils.exit(0)
+ }
+ } else {
+ logError(s"DotnetBackend did not initialize in $backendTimeout seconds")
+ DotnetUtils.exit(-1)
+ }
+ }
+
+ // When the executable is downloaded as part of a zip file, check whether the file exists
+ // after the zip file is unzipped under the given dir. Once it is found, change its
+ // permissions to executable (only on Unix systems, since the zip file may have been
+ // created under Windows). Finally, the absolute path of the executable is returned.
+ private def resolveDotnetExecutable(dir: File, dotnetExecutable: String): String = {
+ val path = Paths.get(dir.getAbsolutePath, dotnetExecutable)
+ val resolvedExecutable = if (Files.isRegularFile(path)) {
+ path.toAbsolutePath.toString
+ } else {
+ Files
+ .walk(FileSystems.getDefault.getPath(dir.getAbsolutePath))
+ .iterator()
+ .asScala
+ .find(path => Files.isRegularFile(path) && path.getFileName.toString == dotnetExecutable) match {
+ case Some(path) => path.toAbsolutePath.toString
+ case None =>
+ throw new IllegalArgumentException(
+ s"Failed to find $dotnetExecutable under ${dir.getAbsolutePath}")
+ }
+ }
+
+ if (DotnetUtils.supportPosix) {
+ Files.setPosixFilePermissions(
+ Paths.get(resolvedExecutable),
+ PosixFilePermissions.fromString("rwxr-xr-x"))
+ }
+
+ resolvedExecutable
+ }
+
+ /**
+ * Download the HDFS file into the supplied directory and return its local path.
+ * Throws an exception if an error occurs during the download.
+ */
+ private def downloadDriverFile(hdfsFilePath: String, driverDir: String): File = {
+ val sparkConf = new SparkConf()
+ val filePath = new Path(hdfsFilePath)
+
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf)
+ val jarFileName = filePath.getName
+ val localFile = new File(driverDir, jarFileName)
+
+ if (!localFile.exists()) { // May already exist if running multiple workers on one node
+ logInfo(s"Copying user file $filePath to $driverDir")
+ Utils.fetchFile(
+ hdfsFilePath,
+ new File(driverDir),
+ sparkConf,
+ hadoopConf,
+ System.currentTimeMillis(),
+ useCache = false)
+ }
+
+ if (!localFile.exists()) {
+ throw new Exception(s"Did not see expected $jarFileName in $driverDir")
+ }
+
+ localFile
+ }
+
+ private def closeBackend(dotnetBackend: DotnetBackend): Unit = {
+ logInfo("Closing DotnetBackend")
+ dotnetBackend.close()
+ }
+
+ private def closeDotnetProcess(dotnetProcess: Process): Int = {
+ if (dotnetProcess == null) {
+ return -1
+ } else if (!dotnetProcess.isAlive) {
+ return dotnetProcess.exitValue()
+ }
+
+ // Try to (gracefully on Linux) kill the process and resort to force if interrupted
+ var returnCode = -1
+ logInfo("Closing .NET process")
+ try {
+ dotnetProcess.destroy()
+ returnCode = dotnetProcess.waitFor()
+ } catch {
+ case _: InterruptedException =>
+ logInfo(
+ "Thread interrupted while waiting for graceful close. Forcefully closing .NET process")
+ returnCode = dotnetProcess.destroyForcibly().waitFor()
+ case t: Throwable =>
+ logThrowable(t)
+ }
+
+ returnCode
+ }
+
+ private def initializeSettings(args: Array[String]): (Boolean, Int) = {
+ val runInDebugMode = (args.length == 1 || args.length == 2) && args(0).equalsIgnoreCase(
+ "debug")
+ var portNumber = 0
+ if (runInDebugMode) {
+ if (args.length == 1) {
+ portNumber = DEBUG_PORT
+ } else if (args.length == 2) {
+ portNumber = Integer.parseInt(args(1))
+ }
+ }
+
+ (runInDebugMode, portNumber)
+ }
+
+ private def logThrowable(throwable: Throwable): Unit =
+ logError(s"${throwable.getMessage} \n ${throwable.getStackTrace.mkString("\n")}")
+}
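
For reference, initializeSettings interprets the arguments as follows (behaviour taken from the code above; argument values are illustrative, not part of the patch):

    // initializeSettings(Array("debug"))          => (true, 5567)  debug mode, default DEBUG_PORT
    // initializeSettings(Array("debug", "5570"))  => (true, 5570)  debug mode, explicit port
    // initializeSettings(Array("app.zip", "MyApp", "arg1"))
    //                                             => (false, 0)    normal mode; port 0 lets the
    //                                                              backend pick a port dynamically
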
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala
new file mode 100644
index 000000000..18ba4c6e5
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.internal.config.dotnet
+
+import org.apache.spark.internal.config.ConfigBuilder
+
+private[spark] object Dotnet {
+ val DOTNET_NUM_BACKEND_THREADS = ConfigBuilder("spark.dotnet.numDotnetBackendThreads").intConf
+ .createWithDefault(10)
+
+ val DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK =
+ ConfigBuilder("spark.dotnet.ignoreSparkPatchVersionCheck").booleanConf
+ .createWithDefault(false)
+
+ val ERROR_REDIRECITON_ENABLED =
+ ConfigBuilder("spark.nonjvm.error.forwarding.enabled").booleanConf
+ .createWithDefault(false)
+
+ val ERROR_BUFFER_SIZE =
+ ConfigBuilder("spark.nonjvm.error.buffer.size")
+ .intConf
+ .checkValue(_ >= 0, "The error buffer size must not be negative")
+ .createWithDefault(10240)
+}
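
These keys are ordinary Spark configuration entries, usually passed with spark-submit --conf; a minimal sketch of setting them on a SparkConf (values are illustrative, not part of the patch):

    val conf = new org.apache.spark.SparkConf()
      .set("spark.dotnet.ignoreSparkPatchVersionCheck", "true")
      .set("spark.nonjvm.error.forwarding.enabled", "true")
      .set("spark.nonjvm.error.buffer.size", "20480")   // must be >= 0
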
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
new file mode 100644
index 000000000..3e3c3e0e3
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
@@ -0,0 +1,26 @@
+
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.mllib.api.dotnet
+
+import org.apache.spark.ml._
+import scala.collection.JavaConverters._
+
+/** MLUtils object that hosts helper functions
+ * related to ML usage
+ */
+object MLUtils {
+
+ /** A helper function that lets a Pipeline accept its stages as a
+ * java.util.ArrayList from Scala code.
+ * @param pipeline - The pipeline whose stages are to be set
+ * @param value - A java.util.ArrayList of PipelineStages to be set as stages
+ * @return The pipeline
+ */
+ def setPipelineStages(pipeline: Pipeline, value: java.util.ArrayList[_ <: PipelineStage]): Pipeline =
+ pipeline.setStages(value.asScala.toArray)
+}
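
Usage sketch (illustrative only, not part of the patch): the caller builds a java.util.ArrayList of stages and hands it to the helper:

    import org.apache.spark.ml.{Pipeline, PipelineStage}
    import org.apache.spark.ml.feature.Tokenizer

    val stages = new java.util.ArrayList[PipelineStage]()
    stages.add(new Tokenizer().setInputCol("text").setOutputCol("words"))
    val pipeline = MLUtils.setPipelineStages(new Pipeline(), stages)   // stages are now set
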
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala
new file mode 100644
index 000000000..5d06d4304
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.sql.api.dotnet
+
+import org.apache.spark.api.dotnet.CallbackClient
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.streaming.DataStreamWriter
+
+class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int) extends Logging {
+ def call(batchDF: DataFrame, batchId: Long): Unit =
+ callbackClient.send(
+ callbackId,
+ (dos, serDe) => {
+ serDe.writeJObj(dos, batchDF)
+ serDe.writeLong(dos, batchId)
+ })
+}
+
+object DotnetForeachBatchHelper {
+ def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = {
+ val dotnetForeachFunc = client match {
+ case Some(value) => new DotnetForeachBatchFunction(value, callbackId)
+ case None => throw new Exception("CallbackClient is null.")
+ }
+
+ dsw.foreachBatch(dotnetForeachFunc.call _)
+ }
+}
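
Wiring sketch (illustrative only, not part of the patch; assumes client: Option[CallbackClient] was registered through the backend and df is a streaming DataFrame): on every micro-batch the batch DataFrame's JVM object id and the batch id are sent to the .NET callback server.

    // DotnetForeachBatchHelper.callForeachBatch(client, df.writeStream, callbackId = 1)
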
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala
new file mode 100644
index 000000000..b5e97289a
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.sql.api.dotnet
+
+import java.util.{List => JList, Map => JMap}
+
+import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonFunction}
+import org.apache.spark.broadcast.Broadcast
+
+object SQLUtils {
+
+ /**
+ * Exposes createPythonFunction to the .NET client to enable registering UDFs.
+ */
+ def createPythonFunction(
+ command: Array[Byte],
+ envVars: JMap[String, String],
+ pythonIncludes: JList[String],
+ pythonExec: String,
+ pythonVersion: String,
+ broadcastVars: JList[Broadcast[PythonBroadcast]],
+ accumulator: PythonAccumulatorV2): PythonFunction = {
+
+ PythonFunction(
+ command,
+ envVars,
+ pythonIncludes,
+ pythonExec,
+ pythonVersion,
+ broadcastVars,
+ accumulator)
+ }
+}
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/test/TestUtils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/test/TestUtils.scala
new file mode 100644
index 000000000..1cd45aa95
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/test/TestUtils.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.sql.test
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.streaming.MemoryStream
+
+object TestUtils {
+
+ /**
+ * Helper method to create typed MemoryStreams intended for use in unit tests.
+ * @param sqlContext The SQLContext.
+ * @param streamType The type of memory stream to create. This string is the `Name`
+ * property of the dotnet type.
+ * @return A typed MemoryStream.
+ */
+ def createMemoryStream(implicit sqlContext: SQLContext, streamType: String): MemoryStream[_] = {
+ import sqlContext.implicits._
+
+ streamType match {
+ case "Int32" => MemoryStream[Int]
+ case "String" => MemoryStream[String]
+ case _ => throw new Exception(s"$streamType not supported")
+ }
+ }
+}
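
Usage sketch from a JVM-side test (illustrative only, not part of the patch; assumes an active SparkSession named spark):

    import org.apache.spark.sql.test.TestUtils
    import org.apache.spark.sql.execution.streaming.MemoryStream

    val intStream = TestUtils.createMemoryStream(spark.sqlContext, "Int32")
      .asInstanceOf[MemoryStream[Int]]
    intStream.addData(1, 2, 3)   // "String" would yield a MemoryStream[String] instead
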
diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/util/dotnet/Utils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/util/dotnet/Utils.scala
new file mode 100644
index 000000000..f9400789a
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/util/dotnet/Utils.scala
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.util.dotnet
+
+import java.io._
+import java.nio.file.attribute.PosixFilePermission
+import java.nio.file.attribute.PosixFilePermission._
+import java.nio.file.{FileSystems, Files}
+import java.util.{Timer, TimerTask}
+
+import org.apache.commons.compress.archivers.zip.{ZipArchiveEntry, ZipArchiveOutputStream, ZipFile}
+import org.apache.commons.io.{FileUtils, IOUtils}
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK
+
+import scala.collection.JavaConverters._
+import scala.collection.Set
+
+/**
+ * Utility methods.
+ */
+object Utils extends Logging {
+ private val posixFilePermissions = Array(
+ OWNER_READ,
+ OWNER_WRITE,
+ OWNER_EXECUTE,
+ GROUP_READ,
+ GROUP_WRITE,
+ GROUP_EXECUTE,
+ OTHERS_READ,
+ OTHERS_WRITE,
+ OTHERS_EXECUTE)
+
+ val supportPosix: Boolean =
+ FileSystems.getDefault.supportedFileAttributeViews().contains("posix")
+
+ /**
+ * Compress all files under the given directory into one zip file and write it to the target zip file.
+ *
+ * @param sourceDir source directory to zip
+ * @param targetZipFile target zip file
+ */
+ def zip(sourceDir: File, targetZipFile: File): Unit = {
+ var fos: FileOutputStream = null
+ var zos: ZipArchiveOutputStream = null
+ try {
+ fos = new FileOutputStream(targetZipFile)
+ zos = new ZipArchiveOutputStream(fos)
+
+ val sourcePath = sourceDir.toPath
+ FileUtils.listFiles(sourceDir, null, true).asScala.foreach { file =>
+ var in: FileInputStream = null
+ try {
+ val path = file.toPath
+ val entry = new ZipArchiveEntry(sourcePath.relativize(path).toString)
+ if (supportPosix) {
+ entry.setUnixMode(
+ permissionsToMode(Files.getPosixFilePermissions(path).asScala)
+ | (if (entry.getName.endsWith(".exe")) 0x1ED else 0x1A4))
+ } else if (entry.getName.endsWith(".exe")) {
+ entry.setUnixMode(0x1ED) // 755
+ } else {
+ entry.setUnixMode(0x1A4) // 644
+ }
+ zos.putArchiveEntry(entry)
+
+ in = new FileInputStream(file)
+ IOUtils.copy(in, zos)
+ zos.closeArchiveEntry()
+ } finally {
+ IOUtils.closeQuietly(in)
+ }
+ }
+ } finally {
+ IOUtils.closeQuietly(zos)
+ IOUtils.closeQuietly(fos)
+ }
+ }
+
+ /**
+ * Unzip a file to the given directory
+ *
+ * @param file file to be unzipped
+ * @param targetDir target directory
+ */
+ def unzip(file: File, targetDir: File): Unit = {
+ var zipFile: ZipFile = null
+ try {
+ targetDir.mkdirs()
+ zipFile = new ZipFile(file)
+ zipFile.getEntries.asScala.foreach { entry =>
+ val targetFile = new File(targetDir, entry.getName)
+
+ if (targetFile.exists()) {
+ logWarning(
+ s"Target file/directory $targetFile already exists. Skip it for now. " +
+ s"Make sure this is expected.")
+ } else {
+ if (entry.isDirectory) {
+ targetFile.mkdirs()
+ } else {
+ targetFile.getParentFile.mkdirs()
+ val input = zipFile.getInputStream(entry)
+ val output = new FileOutputStream(targetFile)
+ IOUtils.copy(input, output)
+ IOUtils.closeQuietly(input)
+ IOUtils.closeQuietly(output)
+ if (supportPosix) {
+ val permissions = modeToPermissions(entry.getUnixMode)
+ // If the entry carries no Unix permission bits (e.g. the zip was created on Windows),
+ // skip setting the empty permission set, which would wipe the file's existing permissions.
+ if (permissions.nonEmpty) {
+ Files.setPosixFilePermissions(targetFile.toPath, permissions.asJava)
+ }
+ }
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("exception caught during decompression:" + e)
+ } finally {
+ ZipFile.closeQuietly(zipFile)
+ }
+ }
+
+ /**
+ * Exits the JVM, trying to do it nicely, otherwise doing it nastily.
+ *
+ * @param status the exit status, zero for OK, non-zero for error
+ * @param maxDelayMillis the maximum delay in milliseconds
+ */
+ def exit(status: Int, maxDelayMillis: Long) {
+ try {
+ logInfo(s"Utils.exit() with status: $status, maxDelayMillis: $maxDelayMillis")
+
+ // setup a timer, so if nice exit fails, the nasty exit happens
+ val timer = new Timer()
+ timer.schedule(new TimerTask() {
+
+ override def run() {
+ Runtime.getRuntime.halt(status)
+ }
+ }, maxDelayMillis)
+ // try to exit nicely
+ System.exit(status);
+ } catch {
+ // exit nastily if we have a problem
+ case _: Throwable => Runtime.getRuntime.halt(status)
+ } finally {
+ // should never get here
+ Runtime.getRuntime.halt(status)
+ }
+ }
+
+ /**
+ * Exits the JVM, trying to do it nicely, waiting at most 1 second before halting.
+ *
+ * @param status the exit status, zero for OK, non-zero for error
+ */
+ def exit(status: Int): Unit = {
+ exit(status, 1000)
+ }
+
+ /**
+ * Normalize the Spark version by taking the first three numbers.
+ * For example:
+ * x.y.z => x.y.z
+ * x.y.z.xxx.yyy => x.y.z
+ * x.y => x.y
+ *
+ * @param version the Spark version to normalize
+ * @return Normalized Spark version.
+ */
+ def normalizeSparkVersion(version: String): String = {
+ version
+ .split('.')
+ .take(3)
+ .zipWithIndex
+ .map({
+ case (element, index) => {
+ index match {
+ case 2 => element.split("\\D+").lift(0).getOrElse("")
+ case _ => element
+ }
+ }
+ })
+ .mkString(".")
+ }
+
+ /**
+ * Validates the normalized spark version by verifying:
+ * - Spark version starts with sparkMajorMinorVersionPrefix.
+ * - If ignoreSparkPatchVersion is
+ * - true: valid
+ * - false: check if the spark version is in supportedSparkVersions.
+ * @param ignoreSparkPatchVersion Ignore spark patch version.
+ * @param sparkVersion The spark version.
+ * @param normalizedSparkVersion The normalized spark version.
+ * @param supportedSparkMajorMinorVersionPrefix The spark major and minor version to validate against.
+ * @param supportedSparkVersions The set of supported spark versions.
+ */
+ def validateSparkVersions(
+ ignoreSparkPatchVersion: Boolean,
+ sparkVersion: String,
+ normalizedSparkVersion: String,
+ supportedSparkMajorMinorVersionPrefix: String,
+ supportedSparkVersions: Set[String]): Unit = {
+ if (!normalizedSparkVersion.startsWith(s"$supportedSparkMajorMinorVersionPrefix.")) {
+ throw new IllegalArgumentException(
+ s"Unsupported spark version used: '$sparkVersion'. " +
+ s"Normalized spark version used: '$normalizedSparkVersion'. " +
+ s"Supported spark major.minor version: '$supportedSparkMajorMinorVersionPrefix'.")
+ } else if (ignoreSparkPatchVersion) {
+ logWarning(
+ s"Ignoring spark patch version. Spark version used: '$sparkVersion'. " +
+ s"Normalized spark version used: '$normalizedSparkVersion'. " +
+ s"Spark major.minor prefix used: '$supportedSparkMajorMinorVersionPrefix'.")
+ } else if (!supportedSparkVersions(normalizedSparkVersion)) {
+ val supportedVersions = supportedSparkVersions.toSeq.sorted.mkString(", ")
+ throw new IllegalArgumentException(
+ s"Unsupported spark version used: '$sparkVersion'. " +
+ s"Normalized spark version used: '$normalizedSparkVersion'. " +
+ s"Supported versions: '$supportedVersions'.")
+ }
+ }
+
+ private[spark] def listZipFileEntries(file: File): Array[String] = {
+ var zipFile: ZipFile = null
+ try {
+ zipFile = new ZipFile(file)
+ zipFile.getEntries.asScala.map(_.getName).toArray
+ } finally {
+ ZipFile.closeQuietly(zipFile)
+ }
+ }
+
+ private[this] def permissionsToMode(permissions: Set[PosixFilePermission]): Int = {
+ posixFilePermissions.foldLeft(0) { (mode, perm) =>
+ (mode << 1) | (if (permissions.contains(perm)) 1 else 0)
+ }
+ }
+
+ private[this] def modeToPermissions(mode: Int): Set[PosixFilePermission] = {
+ posixFilePermissions.zipWithIndex
+ .filter { case (_, i) => (mode & (0x100 >>> i)) != 0 }
+ .map(_._1)
+ .toSet
+ }
+}
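
Behaviour sketch for version normalization (values are illustrative, derived from the code above; not part of the patch):

    Utils.normalizeSparkVersion("3.4.0")             // "3.4.0"
    Utils.normalizeSparkVersion("3.4.1-SNAPSHOT")    // "3.4.1"  (non-digit suffix dropped)
    Utils.normalizeSparkVersion("3.4.0.2.1.8.0-802") // "3.4.0"  (only the first three parts kept)
    Utils.normalizeSparkVersion("3.4")               // "3.4"
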
diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala
new file mode 100644
index 000000000..7088537e1
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+
+package org.apache.spark.api.dotnet
+
+import Extensions._
+import org.junit.Assert._
+import org.junit.{After, Before, Test}
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+@Test
+class DotnetBackendHandlerTest {
+ private var backend: DotnetBackend = _
+ private var tracker: JVMObjectTracker = _
+ private var handler: DotnetBackendHandler = _
+
+ @Before
+ def before(): Unit = {
+ backend = new DotnetBackend
+ tracker = new JVMObjectTracker
+ handler = new DotnetBackendHandler(backend, tracker)
+ }
+
+ @After
+ def after(): Unit = {
+ backend.close()
+ }
+
+ @Test
+ def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = {
+ val message = givenMessage(m => {
+ val serDe = new SerDe(null)
+ m.writeBoolean(true) // static method
+ serDe.writeInt(m, 1) // processId
+ serDe.writeInt(m, 1) // threadId
+ serDe.writeString(m, "DotnetHandler") // class name
+ serDe.writeString(m, "connectCallback") // command (method) name
+ m.writeInt(2) // number of arguments
+ m.writeByte('c') // 1st argument type (string)
+ serDe.writeString(m, "127.0.0.1") // 1st argument value (host)
+ m.writeByte('i') // 2nd argument type (integer)
+ m.writeInt(0) // 2nd argument value (port)
+ })
+
+ val payload = handler.handleBackendRequest(message)
+ val reply = new DataInputStream(new ByteArrayInputStream(payload))
+
+ assertEquals(
+ "status code must be successful.", 0, reply.readInt())
+ assertEquals('j', reply.readByte())
+ assertEquals(1, reply.readInt())
+ val trackingId = new String(reply.readNBytes(1), "UTF-8")
+ assertEquals("1", trackingId)
+ val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull
+ assertEquals(classOf[CallbackClient], client.getClass)
+ }
+
+ private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = {
+ val buffer = new ByteArrayOutputStream()
+ func(new DataOutputStream(buffer))
+ buffer.toByteArray
+ }
+}
diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
new file mode 100644
index 000000000..445486bbd
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+
+package org.apache.spark.api.dotnet
+
+import org.junit.Assert._
+import org.junit.{After, Before, Test}
+
+import java.net.InetAddress
+
+@Test
+class DotnetBackendTest {
+ private var backend: DotnetBackend = _
+
+ @Before
+ def before(): Unit = {
+ backend = new DotnetBackend
+ }
+
+ @After
+ def after(): Unit = {
+ backend.close()
+ }
+
+ @Test
+ def shouldNotResetCallbackClient(): Unit = {
+ // Specifying port = 0 to select port dynamically.
+ backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
+
+ assertTrue(backend.callbackClient.isDefined)
+ assertThrows(classOf[Exception], () => {
+ backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
+ })
+ }
+}
diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
new file mode 100644
index 000000000..c6904403b
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+
+package org.apache.spark.api.dotnet
+
+import java.io.DataInputStream
+
+private[dotnet] object Extensions {
+ implicit class DataInputStreamExt(stream: DataInputStream) {
+ def readNBytes(n: Int): Array[Byte] = {
+ val buf = new Array[Byte](n)
+ stream.readFully(buf)
+ buf
+ }
+ }
+}
diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
new file mode 100644
index 000000000..43ae79005
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import org.junit.Test
+
+@Test
+class JVMObjectTrackerTest {
+
+ @Test
+ def shouldReleaseAllReferences(): Unit = {
+ val tracker = new JVMObjectTracker
+ val firstId = tracker.put(new Object)
+ val secondId = tracker.put(new Object)
+ val thirdId = tracker.put(new Object)
+
+ tracker.clear()
+
+ assert(tracker.get(firstId).isEmpty)
+ assert(tracker.get(secondId).isEmpty)
+ assert(tracker.get(thirdId).isEmpty)
+ }
+
+ @Test
+ def shouldResetCounter(): Unit = {
+ val tracker = new JVMObjectTracker
+ val firstId = tracker.put(new Object)
+ val secondId = tracker.put(new Object)
+
+ tracker.clear()
+
+ val thirdId = tracker.put(new Object)
+
+ assert(firstId.equals("1"))
+ assert(secondId.equals("2"))
+ assert(thirdId.equals("1"))
+ }
+}
diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
new file mode 100644
index 000000000..41401d680
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
@@ -0,0 +1,373 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import org.apache.spark.api.dotnet.Extensions._
+import org.apache.spark.sql.Row
+import org.junit.Assert._
+import org.junit.{Before, Test}
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.sql.Date
+import scala.collection.JavaConverters._
+
+@Test
+class SerDeTest {
+ private var serDe: SerDe = _
+ private var tracker: JVMObjectTracker = _
+
+ @Before
+ def before(): Unit = {
+ tracker = new JVMObjectTracker
+ serDe = new SerDe(tracker)
+ }
+
+ @Test
+ def shouldReadNull(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('n')
+ })
+
+ assertEquals(null, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldThrowForUnsupportedTypes(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('_')
+ })
+
+ assertThrows(classOf[IllegalArgumentException], () => {
+ serDe.readObject(input)
+ })
+ }
+
+ @Test
+ def shouldReadInteger(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('i')
+ in.writeInt(42)
+ })
+
+ assertEquals(42, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadLong(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('g')
+ in.writeLong(42)
+ })
+
+ assertEquals(42L, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadDouble(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('d')
+ in.writeDouble(42.42)
+ })
+
+ assertEquals(42.42, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadBoolean(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('b')
+ in.writeBoolean(true)
+ })
+
+ assertEquals(true, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadString(): Unit = {
+ val payload = "Spark Dotnet"
+ val input = givenInput(in => {
+ in.writeByte('c')
+ in.writeInt(payload.getBytes("UTF-8").length)
+ in.write(payload.getBytes("UTF-8"))
+ })
+
+ assertEquals(payload, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadMap(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('e') // map type descriptor
+ in.writeInt(3) // size
+ in.writeByte('i') // key type
+ in.writeInt(3) // number of keys
+ in.writeInt(11) // first key
+ in.writeInt(22) // second key
+ in.writeInt(33) // third key
+ in.writeInt(3) // number of values
+ in.writeByte('b') // first value type
+ in.writeBoolean(true) // first value
+ in.writeByte('d') // second value type
+ in.writeDouble(42.42) // second value
+ in.writeByte('n') // third type & value
+ })
+
+ assertEquals(
+ mapAsJavaMap(Map(
+ 11 -> true,
+ 22 -> 42.42,
+ 33 -> null)),
+ serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadEmptyMap(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('e') // map type descriptor
+ in.writeInt(0) // size
+ })
+
+ assertEquals(mapAsJavaMap(Map()), serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadBytesArray(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('r') // byte array type descriptor
+ in.writeInt(3) // length
+ in.write(Array[Byte](1, 2, 3)) // payload
+ })
+
+ assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]])
+ }
+
+ @Test
+ def shouldReadEmptyBytesArray(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('r') // byte array type descriptor
+ in.writeInt(0) // length
+ })
+
+ assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]])
+ }
+
+ @Test
+ def shouldReadEmptyList(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('l') // type descriptor
+ in.writeByte('i') // element type
+ in.writeInt(0) // length
+ })
+
+ assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]])
+ }
+
+ @Test
+ def shouldReadList(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('l') // type descriptor
+ in.writeByte('b') // element type
+ in.writeInt(3) // length
+ in.writeBoolean(true)
+ in.writeBoolean(false)
+ in.writeBoolean(true)
+ })
+
+ assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]])
+ }
+
+ @Test
+ def shouldThrowWhenReadingListWithUnsupportedType(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('l') // type descriptor
+ in.writeByte('_') // unsupported element type
+ })
+
+ assertThrows(classOf[IllegalArgumentException], () => {
+ serDe.readObject(input)
+ })
+ }
+
+ @Test
+ def shouldReadDate(): Unit = {
+ val input = givenInput(in => {
+ val date = "2020-12-31"
+ in.writeByte('D') // type descriptor
+ in.writeInt(date.getBytes("UTF-8").length) // date string size
+ in.write(date.getBytes("UTF-8"))
+ })
+
+ assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input))
+ }
+
+ @Test
+ def shouldReadObject(): Unit = {
+ val trackingObject = new Object
+ tracker.put(trackingObject)
+ val input = givenInput(in => {
+ val objectIndex = "1"
+ in.writeByte('j') // type descriptor
+ in.writeInt(objectIndex.getBytes("UTF-8").length) // size
+ in.write(objectIndex.getBytes("UTF-8"))
+ })
+
+ assertSame(trackingObject, serDe.readObject(input))
+ }
+
+ @Test
+ def shouldThrowWhenReadingNonTrackingObject(): Unit = {
+ val input = givenInput(in => {
+ val objectIndex = "42"
+ in.writeByte('j') // type descriptor
+ in.writeInt(objectIndex.getBytes("UTF-8").length) // size
+ in.write(objectIndex.getBytes("UTF-8"))
+ })
+
+ assertThrows(classOf[NoSuchElementException], () => {
+ serDe.readObject(input)
+ })
+ }
+
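+ // Row batches ('R') carry a row count, then each row's element count followed by individually typed elements.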
+ @Test
+ def shouldReadSparkRows(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('R') // type descriptor
+ in.writeInt(2) // number of rows
+ in.writeInt(1) // number of elements in 1st row
+ in.writeByte('i') // type of 1st element in 1st row
+ in.writeInt(11)
+ in.writeInt(3) // number of elements in 2nd row
+ in.writeByte('b') // type of 1st element in 2nd row
+ in.writeBoolean(true)
+ in.writeByte('d') // type of 2nd element in 2nd row
+ in.writeDouble(42.24)
+ in.writeByte('g') // type of 3rd element in 2nd row
+ in.writeLong(99)
+ })
+
+ assertEquals(
+ seqAsJavaList(Seq(
+ Row.fromSeq(Seq(11)),
+ Row.fromSeq(Seq(true, 42.24, 99)))),
+ serDe.readObject(input))
+ }
+
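+ // 'O' marks a heterogeneous array: an element count followed by typed elements, deserialized here as a Seq.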
+ @Test
+ def shouldReadArrayOfObjects(): Unit = {
+ val input = givenInput(in => {
+ in.writeByte('O') // type descriptor
+ in.writeInt(2) // number of elements
+ in.writeByte('i') // type of 1st element
+ in.writeInt(42)
+ in.writeByte('b') // type of 2nd element
+ in.writeBoolean(true)
+ })
+
+ assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]])
+ }
+
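+ // Both null and the Unit value are expected to be written as the null marker 'n'.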
+ @Test
+ def shouldWriteNull(): Unit = {
+ val in = whenOutput(out => {
+ serDe.writeObject(out, null)
+ serDe.writeObject(out, Unit)
+ })
+
+ assertEquals(in.readByte(), 'n')
+ assertEquals(in.readByte(), 'n')
+ assertEndOfStream(in)
+ }
+
+ @Test
+ def shouldWriteString(): Unit = {
+ val sparkDotnet = "Spark Dotnet"
+ val in = whenOutput(out => {
+ serDe.writeObject(out, sparkDotnet)
+ })
+
+ assertEquals(in.readByte(), 'c') // object type
+ assertEquals(in.readInt(), sparkDotnet.length) // length
+ assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8"))
+ assertEndOfStream(in)
+ }
+
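+ // Type tags asserted below: 'd' double (floats are widened), 'g' long, 'i' int, 'b' boolean.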
+ @Test
+ def shouldWritePrimitiveTypes(): Unit = {
+ val in = whenOutput(out => {
+ serDe.writeObject(out, 42.24f.asInstanceOf[Object])
+ serDe.writeObject(out, 42L.asInstanceOf[Object])
+ serDe.writeObject(out, 42.asInstanceOf[Object])
+ serDe.writeObject(out, true.asInstanceOf[Object])
+ })
+
+ assertEquals(in.readByte(), 'd')
+ assertEquals(in.readDouble(), 42.24F, 0.000001)
+ assertEquals(in.readByte(), 'g')
+ assertEquals(in.readLong(), 42L)
+ assertEquals(in.readByte(), 'i')
+ assertEquals(in.readInt(), 42)
+ assertEquals(in.readByte(), 'b')
+ assertEquals(in.readBoolean(), true)
+ assertEndOfStream(in)
+ }
+
+ @Test
+ def shouldWriteDate(): Unit = {
+ val date = "2020-12-31"
+ val in = whenOutput(out => {
+ serDe.writeObject(out, Date.valueOf(date))
+ })
+
+ assertEquals(in.readByte(), 'D') // type
+ assertEquals(in.readInt(), 10) // size
+ assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content
+ }
+
+ @Test
+ def shouldWriteCustomObjects(): Unit = {
+ val customObject = new Object
+ val in = whenOutput(out => {
+ serDe.writeObject(out, customObject)
+ })
+
+ assertEquals(in.readByte(), 'j')
+ assertEquals(in.readInt(), 1)
+ assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8"))
+ assertSame(tracker.get("1").get, customObject)
+ }
+
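+ // Arrays of non-primitive objects are written as an 'l' list of 'j' references, one tracker identifier per element.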
+ @Test
+ def shouldWriteArrayOfCustomObjects(): Unit = {
+ val payload = Array(new Object, new Object)
+ val in = whenOutput(out => {
+ serDe.writeObject(out, payload)
+ })
+
+ assertEquals(in.readByte(), 'l') // array type
+ assertEquals(in.readByte(), 'j') // type of element in array
+ assertEquals(in.readInt(), 2) // array length
+ assertEquals(in.readInt(), 1) // size of 1st element's identifier
+ assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element
+ assertEquals(in.readInt(), 1) // size of 2nd element's identifier
+ assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element
+ assertSame(tracker.get("1").get, payload(0))
+ assertSame(tracker.get("2").get, payload(1))
+ }
+
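+ // Test helpers: build an in-memory stream to feed data to the SerDe (givenInput) and to inspect what it wrote (whenOutput).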
+ private def givenInput(func: DataOutputStream => Unit): DataInputStream = {
+ val buffer = new ByteArrayOutputStream()
+ val out = new DataOutputStream(buffer)
+ func(out)
+ new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
+ }
+
+ private def whenOutput = givenInput _
+
+ private def assertEndOfStream(in: DataInputStream): Unit = {
+ assertEquals(-1, in.read())
+ }
+}
diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala
new file mode 100644
index 000000000..9d4379436
--- /dev/null
+++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.util.dotnet
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK
+import org.junit.Assert.{assertEquals, assertThrows}
+import org.junit.Test
+
+@Test
+class UtilsTest {
+
+ @Test
+ def shouldIgnorePatchVersion(): Unit = {
+ val sparkVersion = "3.4.0"
+ val sparkMajorMinorVersionPrefix = "3.4"
+ val supportedSparkVersions = Set[String]("3.4.0")
+
+ Utils.validateSparkVersions(
+ true,
+ sparkVersion,
+ Utils.normalizeSparkVersion(sparkVersion),
+ sparkMajorMinorVersionPrefix,
+ supportedSparkVersions)
+ }
+
+ @Test
+ def shouldThrowForUnsupportedVersion(): Unit = {
+ val sparkVersion = "3.4.0"
+ val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion)
+ val sparkMajorMinorVersionPrefix = "3.4"
+ val supportedSparkVersions = Set[String]("3.4.0")
+
+ val exception = assertThrows(
+ classOf[IllegalArgumentException],
+ () => {
+ Utils.validateSparkVersions(
+ false,
+ sparkVersion,
+ normalizedSparkVersion,
+ sparkMajorMinorVersionPrefix,
+ supportedSparkVersions)
+ })
+
+ assertEquals(
+ s"Unsupported spark version used: '$sparkVersion'. " +
+ s"Normalized spark version used: '$normalizedSparkVersion'. " +
+ s"Supported versions: '${supportedSparkVersions.toSeq.sorted.mkString(", ")}'.",
+ exception.getMessage)
+ }
+
+ @Test
+ def shouldThrowForUnsupportedMajorMinorVersion(): Unit = {
+ val sparkVersion = "2.4.4"
+ val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion)
+ val sparkMajorMinorVersionPrefix = "3.4"
+ val supportedSparkVersions = Set[String]("3.4.0")
+
+ val exception = assertThrows(
+ classOf[IllegalArgumentException],
+ () => {
+ Utils.validateSparkVersions(
+ false,
+ sparkVersion,
+ normalizedSparkVersion,
+ sparkMajorMinorVersionPrefix,
+ supportedSparkVersions)
+ })
+
+ assertEquals(
+ s"Unsupported spark version used: '$sparkVersion'. " +
+ s"Normalized spark version used: '$normalizedSparkVersion'. " +
+ s"Supported spark major.minor version: '$sparkMajorMinorVersionPrefix'.",
+ exception.getMessage)
+ }
+}
diff --git a/src/scala/pom.xml b/src/scala/pom.xml
index 890af92b8..53fe6d4f2 100644
--- a/src/scala/pom.xml
+++ b/src/scala/pom.xml
@@ -15,6 +15,8 @@
microsoft-spark-3-0
microsoft-spark-3-1
microsoft-spark-3-2
+ microsoft-spark-3-3
+ microsoft-spark-3-4