From b67de69fe2e6d06c7ce6c58e71001dc4084a51ee Mon Sep 17 00:00:00 2001
From: Steve Suh
Date: Thu, 8 Apr 2021 11:55:42 -0700
Subject: [PATCH 1/8] remove 2.3 jar

---
 src/scala/microsoft-spark-2-3/pom.xml         |  77 ----
 .../spark/api/dotnet/CallbackClient.scala     |  72 ----
 .../spark/api/dotnet/CallbackConnection.scala | 112 -----
 .../spark/api/dotnet/DotnetBackend.scala      | 114 ------
 .../api/dotnet/DotnetBackendHandler.scala     | 329 ---------------
 .../spark/api/dotnet/DotnetException.scala    |  13 -
 .../apache/spark/api/dotnet/DotnetRDD.scala   |  30 --
 .../spark/api/dotnet/JVMObjectTracker.scala   |  54 ---
 .../spark/api/dotnet/JvmBridgeUtils.scala     |  33 --
 .../org/apache/spark/api/dotnet/SerDe.scala   | 386 ------------------
 .../apache/spark/api/dotnet/ThreadPool.scala  |  72 ----
 .../spark/deploy/dotnet/DotnetRunner.scala    | 286 -------------
 .../spark/internal/config/dotnet/Dotnet.scala |  18 -
 .../spark/sql/api/dotnet/SQLUtils.scala       |  37 --
 .../org/apache/spark/sql/test/TestUtils.scala |  30 --
 .../org/apache/spark/util/dotnet/Utils.scala  | 255 ------------
 .../api/dotnet/DotnetBackendHandlerTest.scala |  61 ---
 .../spark/api/dotnet/DotnetBackendTest.scala  |  43 --
 .../apache/spark/api/dotnet/Extensions.scala  |  19 -
 .../api/dotnet/JVMObjectTrackerTest.scala     |  42 --
 .../apache/spark/api/dotnet/SerDeTest.scala   | 386 ------------------
 .../apache/spark/util/dotnet/UtilsTest.scala  |  85 ----
 src/scala/pom.xml                             |   1 -
 23 files changed, 2555 deletions(-)
 delete mode 100644 src/scala/microsoft-spark-2-3/pom.xml
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/sql/test/TestUtils.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/util/dotnet/Utils.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala
 delete mode 100644 src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala

diff --git a/src/scala/microsoft-spark-2-3/pom.xml b/src/scala/microsoft-spark-2-3/pom.xml
deleted file mode 100644
index 8cde04b6e..000000000
--- a/src/scala/microsoft-spark-2-3/pom.xml
+++ /dev/null
@@ -1,77 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0"
-         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-  <modelVersion>4.0.0</modelVersion>
-  <parent>
-    <groupId>com.microsoft.scala</groupId>
-    <artifactId>microsoft-spark</artifactId>
-    <version>${microsoft-spark.version}</version>
-  </parent>
-  <artifactId>microsoft-spark-2-3_2.11</artifactId>
-  <inceptionYear>2019</inceptionYear>
-  <properties>
-    <encoding>UTF-8</encoding>
-    <scala.version>2.11.8</scala.version>
-    <scala.binary.version>2.11</scala.binary.version>
-    <spark.version>2.3.4</spark.version>
-  </properties>
-
-  <dependencies>
-    <dependency>
-      <groupId>org.scala-lang</groupId>
-      <artifactId>scala-library</artifactId>
-      <version>${scala.version}</version>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.spark</groupId>
-      <artifactId>spark-core_${scala.binary.version}</artifactId>
-      <version>${spark.version}</version>
-      <scope>provided</scope>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.spark</groupId>
-      <artifactId>spark-sql_${scala.binary.version}</artifactId>
-      <version>${spark.version}</version>
-      <scope>provided</scope>
-    </dependency>
-    <dependency>
-      <groupId>junit</groupId>
-      <artifactId>junit</artifactId>
-      <version>4.13.1</version>
-      <scope>test</scope>
-    </dependency>
-    <dependency>
-      <groupId>org.specs</groupId>
-      <artifactId>specs</artifactId>
-      <version>1.2.5</version>
-      <scope>test</scope>
-    </dependency>
-  </dependencies>
-
-  <build>
-    <sourceDirectory>src/main/scala</sourceDirectory>
-    <testSourceDirectory>src/test/scala</testSourceDirectory>
-    <plugins>
-      <plugin>
-        <groupId>org.scala-tools</groupId>
-        <artifactId>maven-scala-plugin</artifactId>
-        <version>2.15.2</version>
-        <executions>
-          <execution>
-            <goals>
-              <goal>compile</goal>
-              <goal>testCompile</goal>
-            </goals>
-          </execution>
-        </executions>
-        <configuration>
-          <scalaVersion>${scala.version}</scalaVersion>
-          <args>
-            <arg>-target:jvm-1.8</arg>
-            <arg>-deprecation</arg>
-            <arg>-feature</arg>
-          </args>
-        </configuration>
-      </plugin>
-    </plugins>
-  </build>
-</project>
diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
deleted file mode 100644
index aea355dfa..000000000
--- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the .NET Foundation under one or more agreements.
- * The .NET Foundation licenses this file to you under the MIT license.
- * See the LICENSE file in the project root for more information.
- */
-
-package org.apache.spark.api.dotnet
-
-import java.io.DataOutputStream
-
-import org.apache.spark.internal.Logging
-
-import scala.collection.mutable.Queue
-
-/**
- * CallbackClient is used to communicate with the Dotnet CallbackServer.
- * The client manages and maintains a pool of open CallbackConnections.
- * Any callback request is delegated to a new CallbackConnection or
- * unused CallbackConnection.
- * @param address The address of the Dotnet CallbackServer - * @param port The port of the Dotnet CallbackServer - */ -class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging { - private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]() - - private[this] var isShutdown: Boolean = false - - final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit = - getOrCreateConnection() match { - case Some(connection) => - try { - connection.send(callbackId, writeBody) - addConnection(connection) - } catch { - case e: Exception => - logError(s"Error calling callback [callback id = $callbackId].", e) - connection.close() - throw e - } - case None => throw new Exception("Unable to get or create connection.") - } - - private def getOrCreateConnection(): Option[CallbackConnection] = synchronized { - if (isShutdown) { - logInfo("Cannot get or create connection while client is shutdown.") - return None - } - - if (connectionPool.nonEmpty) { - return Some(connectionPool.dequeue()) - } - - Some(new CallbackConnection(serDe, address, port)) - } - - private def addConnection(connection: CallbackConnection): Unit = synchronized { - assert(connection != null) - connectionPool.enqueue(connection) - } - - def shutdown(): Unit = synchronized { - if (isShutdown) { - logInfo("Shutdown called, but already shutdown.") - return - } - - logInfo("Shutting down.") - connectionPool.foreach(_.close) - connectionPool.clear - isShutdown = true - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala deleted file mode 100644 index 604cf029b..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import java.io.{ByteArrayOutputStream, Closeable, DataInputStream, DataOutputStream} -import java.net.Socket - -import org.apache.spark.internal.Logging - -/** - * CallbackConnection is used to process the callback communication - * between the JVM and Dotnet. It uses a TCP socket to communicate with - * the Dotnet CallbackServer and the socket is expected to be reused. 
- * @param address The address of the Dotnet CallbackServer - * @param port The port of the Dotnet CallbackServer - */ -class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging { - private[this] val socket: Socket = new Socket(address, port) - private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream) - private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream) - - def send( - callbackId: Int, - writeBody: (DataOutputStream, SerDe) => Unit): Unit = { - logInfo(s"Calling callback [callback id = $callbackId] ...") - - try { - serDe.writeInt(outputStream, CallbackFlags.CALLBACK) - serDe.writeInt(outputStream, callbackId) - - val byteArrayOutputStream = new ByteArrayOutputStream() - writeBody(new DataOutputStream(byteArrayOutputStream), serDe) - serDe.writeInt(outputStream, byteArrayOutputStream.size) - byteArrayOutputStream.writeTo(outputStream); - } catch { - case e: Exception => { - throw new Exception("Error writing to stream.", e) - } - } - - logInfo(s"Signaling END_OF_STREAM.") - try { - serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) - outputStream.flush() - - val endOfStreamResponse = readFlag(inputStream) - endOfStreamResponse match { - case CallbackFlags.END_OF_STREAM => - logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.") - case _ => { - throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " + - s"Received: $endOfStreamResponse") - } - } - } catch { - case e: Exception => { - throw new Exception("Error while verifying end of stream.", e) - } - } - } - - def close(): Unit = { - try { - serDe.writeInt(outputStream, CallbackFlags.CLOSE) - outputStream.flush() - } catch { - case e: Exception => logInfo("Unable to send close to .NET callback server.", e) - } - - close(socket) - close(outputStream) - close(inputStream) - } - - private def close(s: Socket): Unit = { - try { - assert(s != null) - s.close() - } catch { - case e: Exception => logInfo("Unable to close socket.", e) - } - } - - private def close(c: Closeable): Unit = { - try { - assert(c != null) - c.close() - } catch { - case e: Exception => logInfo("Unable to close closeable.", e) - } - } - - private def readFlag(inputStream: DataInputStream): Int = { - val callbackFlag = serDe.readInt(inputStream) - if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) { - val exceptionMessage = serDe.readString(inputStream) - throw new DotnetException(exceptionMessage) - } - callbackFlag - } - - private object CallbackFlags { - val CLOSE: Int = -1 - val CALLBACK: Int = -2 - val DOTNET_EXCEPTION_THROWN: Int = -3 - val END_OF_STREAM: Int = -4 - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala deleted file mode 100644 index 1d8215d44..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package org.apache.spark.api.dotnet - -import java.net.InetSocketAddress -import java.util.concurrent.TimeUnit - -import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} -import io.netty.handler.codec.LengthFieldBasedFrameDecoder -import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_NUM_BACKEND_THREADS -import org.apache.spark.{SparkConf, SparkEnv} - -/** - * Netty server that invokes JVM calls based upon receiving messages from .NET. - * The implementation mirrors the RBackend. - * - */ -class DotnetBackend extends Logging { - self => // for accessing the this reference in inner class(ChannelInitializer) - private[this] var channelFuture: ChannelFuture = _ - private[this] var bootstrap: ServerBootstrap = _ - private[this] var bossGroup: EventLoopGroup = _ - private[this] val objectTracker = new JVMObjectTracker - - @volatile - private[dotnet] var callbackClient: Option[CallbackClient] = None - - def init(portNumber: Int): Int = { - val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS) - logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.") - bossGroup = new NioEventLoopGroup(numBackendThreads) - val workerGroup = bossGroup - - bootstrap = new ServerBootstrap() - .group(bossGroup, workerGroup) - .channel(classOf[NioServerSocketChannel]) - - bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { - def initChannel(ch: SocketChannel): Unit = { - ch.pipeline() - .addLast("encoder", new ByteArrayEncoder()) - .addLast( - "frameDecoder", - // maxFrameLength = 2G - // lengthFieldOffset = 0 - // lengthFieldLength = 4 - // lengthAdjustment = 0 - // initialBytesToStrip = 4, i.e. 
strip out the length field itself - new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) - .addLast("decoder", new ByteArrayDecoder()) - .addLast("handler", new DotnetBackendHandler(self, objectTracker)) - } - }) - - channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber)) - channelFuture.syncUninterruptibly() - channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort - } - - private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { - callbackClient = callbackClient match { - case Some(_) => throw new Exception("Callback client already set.") - case None => - logInfo(s"Connecting to a callback server at $address:$port") - Some(new CallbackClient(new SerDe(objectTracker), address, port)) - } - } - - private[dotnet] def shutdownCallbackClient(): Unit = synchronized { - callbackClient match { - case Some(client) => client.shutdown() - case None => logInfo("Callback server has already been shutdown.") - } - callbackClient = None - } - - def run(): Unit = { - channelFuture.channel.closeFuture().syncUninterruptibly() - } - - def close(): Unit = { - if (channelFuture != null) { - // close is a local operation and should finish within milliseconds; timeout just to be safe - channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) - channelFuture = null - } - if (bootstrap != null && bootstrap.config().group() != null) { - bootstrap.config().group().shutdownGracefully() - } - if (bootstrap != null && bootstrap.config().childGroup() != null) { - bootstrap.config().childGroup().shutdownGracefully() - } - bootstrap = null - - objectTracker.clear() - - // Send close to .NET callback server. - shutdownCallbackClient() - - // Shutdown the thread pool whose executors could still be running. - ThreadPool.shutdown() - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala deleted file mode 100644 index 298d786ea..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala +++ /dev/null @@ -1,329 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import scala.language.existentials - -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils - -/** - * Handler for DotnetBackend. - * This implementation is similar to RBackendHandler. 
- */ -class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker) - extends SimpleChannelInboundHandler[Array[Byte]] - with Logging { - - private[this] val serDe = new SerDe(objectsTracker) - - override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { - val reply = handleBackendRequest(msg) - ctx.write(reply) - } - - override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { - ctx.flush() - } - - def handleBackendRequest(msg: Array[Byte]): Array[Byte] = { - val bis = new ByteArrayInputStream(msg) - val dis = new DataInputStream(bis) - - val bos = new ByteArrayOutputStream() - val dos = new DataOutputStream(bos) - - // First bit is isStatic - val isStatic = serDe.readBoolean(dis) - val processId = serDe.readInt(dis) - val threadId = serDe.readInt(dis) - val objId = serDe.readString(dis) - val methodName = serDe.readString(dis) - val numArgs = serDe.readInt(dis) - - if (objId == "DotnetHandler") { - methodName match { - case "stopBackend" => - serDe.writeInt(dos, 0) - serDe.writeType(dos, "void") - server.close() - case "rm" => - try { - val t = serDe.readObjectType(dis) - assert(t == 'c') - val objToRemove = serDe.readString(dis) - objectsTracker.remove(objToRemove) - serDe.writeInt(dos, 0) - serDe.writeObject(dos, null) - } catch { - case e: Exception => - logError(s"Removing $objId failed", e) - serDe.writeInt(dos, -1) - } - case "rmThread" => - try { - assert(serDe.readObjectType(dis) == 'i') - val processId = serDe.readInt(dis) - assert(serDe.readObjectType(dis) == 'i') - val threadToDelete = serDe.readInt(dis) - val result = ThreadPool.tryDeleteThread(processId, threadToDelete) - serDe.writeInt(dos, 0) - serDe.writeObject(dos, result.asInstanceOf[AnyRef]) - } catch { - case e: Exception => - logError(s"Removing thread $threadId failed", e) - serDe.writeInt(dos, -1) - } - case "connectCallback" => - assert(serDe.readObjectType(dis) == 'c') - val address = serDe.readString(dis) - assert(serDe.readObjectType(dis) == 'i') - val port = serDe.readInt(dis) - server.setCallbackClient(address, port) - serDe.writeInt(dos, 0) - serDe.writeType(dos, "void") - case "closeCallback" => - logInfo("Requesting to close callback client") - server.shutdownCallbackClient() - serDe.writeInt(dos, 0) - serDe.writeType(dos, "void") - - case _ => dos.writeInt(-1) - } - } else { - ThreadPool - .run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) - } - - bos.toByteArray - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - // Skip logging the exception message if the connection was disconnected from - // the .NET side so that .NET side doesn't have to explicitly close the connection via - // "stopBackend." Note that an exception is still thrown if the exit status is non-zero, - // so skipping this kind of exception message does not affect the debugging. - if (!cause.getMessage.contains( - "An existing connection was forcibly closed by the remote host")) { - logError("Exception caught: ", cause) - } - - // Close the connection when an exception is raised. 
- ctx.close() - } - - def handleMethodCall( - isStatic: Boolean, - objId: String, - methodName: String, - numArgs: Int, - dis: DataInputStream, - dos: DataOutputStream): Unit = { - var obj: Object = null - var args: Array[java.lang.Object] = null - var methods: Array[java.lang.reflect.Method] = null - - try { - val cls = if (isStatic) { - Utils.classForName(objId) - } else { - objectsTracker.get(objId) match { - case None => throw new IllegalArgumentException("Object not found " + objId) - case Some(o) => - obj = o - o.getClass - } - } - - args = readArgs(numArgs, dis) - methods = cls.getMethods - - val selectedMethods = methods.filter(m => m.getName == methodName) - if (selectedMethods.length > 0) { - val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args) - - if (index.isEmpty) { - logWarning( - s"cannot find matching method ${cls}.$methodName. " - + s"Candidates are:") - selectedMethods.foreach { method => - logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") - } - throw new Exception(s"No matched method found for $cls.$methodName") - } - - val ret = selectedMethods(index.get).invoke(obj, args: _*) - - // Write status bit - serDe.writeInt(dos, 0) - serDe.writeObject(dos, ret.asInstanceOf[AnyRef]) - } else if (methodName == "") { - // methodName should be "" for constructor - val ctor = cls.getConstructors.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - }.head - - val obj = ctor.newInstance(args: _*) - - serDe.writeInt(dos, 0) - serDe.writeObject(dos, obj.asInstanceOf[AnyRef]) - } else { - throw new IllegalArgumentException( - "invalid method " + methodName + " for object " + objId) - } - } catch { - case e: Throwable => - val jvmObj = objectsTracker.get(objId) - val jvmObjName = jvmObj match { - case Some(jObj) => jObj.getClass.getName - case None => "NullObject" - } - val argsStr = args - .map(arg => { - if (arg != null) { - s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" - } else { - "[Value: NULL]" - } - }) - .mkString(", ") - - logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") - - if (methods != null) { - logDebug(s"All methods for $jvmObjName:") - methods.foreach(m => logDebug(m.toString)) - } - - serDe.writeInt(dos, -1) - serDe.writeString(dos, Utils.exceptionString(e.getCause)) - } - } - - // Read a number of arguments from the data input stream - def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { - (0 until numArgs).map { arg => - serDe.readObject(dis) - }.toArray - } - - // Checks if the arguments passed in args matches the parameter types. - // NOTE: Currently we do exact match. We may add type conversions later. 
- def matchMethod( - numArgs: Int, - args: Array[java.lang.Object], - parameterTypes: Array[Class[_]]): Boolean = { - if (parameterTypes.length != numArgs) { - return false - } - - for (i <- 0 until numArgs) { - val parameterType = parameterTypes(i) - var parameterWrapperType = parameterType - - // Convert native parameters to Object types as args is Array[Object] here - if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Long] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType - } - } - - if (!parameterWrapperType.isInstance(args(i))) { - // non primitive types - if (!parameterType.isPrimitive && args(i) != null) { - return false - } - - // primitive types - if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) { - return false - } - } - } - - true - } - - // Find a matching method signature in an array of signatures of constructors - // or methods of the same name according to the passed arguments. Arguments - // may be converted in order to match a signature. - // - // Note that in Java reflection, constructors and normal methods are of different - // classes, and share no parent class that provides methods for reflection uses. - // There is no unified way to handle them in this function. So an array of signatures - // is passed in instead of an array of candidate constructors or methods. - // - // Returns an Option[Int] which is the index of the matched signature in the array. - def findMatchedSignature( - parameterTypesOfMethods: Array[Array[Class[_]]], - args: Array[Object]): Option[Int] = { - val numArgs = args.length - - for (index <- parameterTypesOfMethods.indices) { - val parameterTypes = parameterTypesOfMethods(index) - - if (parameterTypes.length == numArgs) { - var argMatched = true - var i = 0 - while (i < numArgs && argMatched) { - val parameterType = parameterTypes(i) - - if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { - // The case that the parameter type is a Scala Seq and the argument - // is a Java array is considered matching. The array will be converted - // to a Seq later if this method is matched. - } else { - var parameterWrapperType = parameterType - - // Convert native parameters to Object types as args is Array[Object] here - if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Long] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType - } - } - if ((parameterType.isPrimitive || args(i) != null) && - !parameterWrapperType.isInstance(args(i))) { - argMatched = false - } - } - - i = i + 1 - } - - if (argMatched) { - // For now, we return the first matching method. - // TODO: find best method in matching methods. 
- - // Convert args if needed - val parameterTypes = parameterTypesOfMethods(index) - - for (i <- 0 until numArgs) { - if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { - // Convert a Java array to scala Seq - args(i) = args(i).asInstanceOf[Array[_]].toSeq - } - } - - return Some(index) - } - } - } - None - } - - def logError(id: String, e: Exception): Unit = {} -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala deleted file mode 100644 index c70d16b03..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -class DotnetException(message: String, cause: Throwable) - extends Exception(message, cause) { - - def this(message: String) = this(message, null) -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala deleted file mode 100644 index f5277c215..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import org.apache.spark.SparkContext -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python._ -import org.apache.spark.rdd.RDD - -object DotnetRDD { - def createPythonRDD( - parent: RDD[_], - func: PythonFunction, - preservePartitoning: Boolean): PythonRDD = { - new PythonRDD(parent, func, preservePartitoning) - } - - def createJavaRDDFromArray( - sc: SparkContext, - arr: Array[Array[Byte]], - numSlices: Int): JavaRDD[Array[Byte]] = { - JavaRDD.fromRDD(sc.parallelize(arr, numSlices)) - } - - def toJavaRDD(rdd: RDD[_]): JavaRDD[_] = JavaRDD.fromRDD(rdd) -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala deleted file mode 100644 index aceb58c01..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import scala.collection.mutable.HashMap - -/** - * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects. - */ -private[dotnet] class JVMObjectTracker { - - // Multiple threads may access objMap and increase objCounter. Because get method return Option, - // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. 
- private[this] val objMap = new HashMap[String, Object] - private[this] var objCounter: Int = 1 - - def getObject(id: String): Object = { - synchronized { - objMap(id) - } - } - - def get(id: String): Option[Object] = { - synchronized { - objMap.get(id) - } - } - - def put(obj: Object): String = { - synchronized { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } - } - - def remove(id: String): Option[Object] = { - synchronized { - objMap.remove(id) - } - } - - def clear(): Unit = { - synchronized { - objMap.clear() - objCounter = 1 - } - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala deleted file mode 100644 index 06a476f67..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.sql.api.dotnet - -import org.apache.spark.SparkConf - -/* - * Utils for JvmBridge. - */ -object JvmBridgeUtils { - def getKeyValuePairAsString(kvp: (String, String)): String = { - return kvp._1 + "=" + kvp._2 - } - - def getKeyValuePairArrayAsString(kvpArray: Array[(String, String)]): String = { - val sb = new StringBuilder - - for (kvp <- kvpArray) { - sb.append(getKeyValuePairAsString(kvp)) - sb.append(";") - } - - sb.toString - } - - def getSparkConfAsString(sparkConf: SparkConf): String = { - getKeyValuePairArrayAsString(sparkConf.getAll) - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala deleted file mode 100644 index 44cad97c1..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ /dev/null @@ -1,386 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import java.io.{DataInputStream, DataOutputStream} -import java.nio.charset.StandardCharsets -import java.sql.{Date, Time, Timestamp} - -import org.apache.spark.sql.Row - -import scala.collection.JavaConverters._ - -/** - * Class responsible for serialization and deserialization between CLR & JVM. - * This implementation of methods is mostly identical to the SerDe implementation in R. 
- */ -class SerDe(val tracker: JVMObjectTracker) { - def readObjectType(dis: DataInputStream): Char = { - dis.readByte().toChar - } - - def readObject(dis: DataInputStream): Object = { - val dataType = readObjectType(dis) - readTypedObject(dis, dataType) - } - - private def readTypedObject(dis: DataInputStream, dataType: Char): Object = { - dataType match { - case 'n' => null - case 'i' => new java.lang.Integer(readInt(dis)) - case 'g' => new java.lang.Long(readLong(dis)) - case 'd' => new java.lang.Double(readDouble(dis)) - case 'b' => new java.lang.Boolean(readBoolean(dis)) - case 'c' => readString(dis) - case 'e' => readMap(dis) - case 'r' => readBytes(dis) - case 'l' => readList(dis) - case 'D' => readDate(dis) - case 't' => readTime(dis) - case 'j' => tracker.getObject(readString(dis)) - case 'R' => readRowArr(dis) - case 'O' => readObjectArr(dis) - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") - } - } - - private def readBytes(in: DataInputStream): Array[Byte] = { - val len = readInt(in) - val out = new Array[Byte](len) - in.readFully(out) - out - } - - def readInt(in: DataInputStream): Int = { - in.readInt() - } - - private def readLong(in: DataInputStream): Long = { - in.readLong() - } - - private def readDouble(in: DataInputStream): Double = { - in.readDouble() - } - - private def readStringBytes(in: DataInputStream, len: Int): String = { - val bytes = new Array[Byte](len) - in.readFully(bytes) - val str = new String(bytes, "UTF-8") - str - } - - def readString(in: DataInputStream): String = { - val len = in.readInt() - readStringBytes(in, len) - } - - def readBoolean(in: DataInputStream): Boolean = { - in.readBoolean() - } - - private def readDate(in: DataInputStream): Date = { - Date.valueOf(readString(in)) - } - - private def readTime(in: DataInputStream): Timestamp = { - val seconds = in.readDouble() - val sec = Math.floor(seconds).toLong - val t = new Timestamp(sec * 1000L) - t.setNanos(((seconds - sec) * 1e9).toInt) - t - } - - private def readRow(in: DataInputStream): Row = { - val len = readInt(in) - Row.fromSeq((0 until len).map(_ => readObject(in))) - } - - private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { - val len = readInt(in) - (0 until len).map(_ => readBytes(in)).toArray - } - - private def readIntArr(in: DataInputStream): Array[Int] = { - val len = readInt(in) - (0 until len).map(_ => readInt(in)).toArray - } - - private def readLongArr(in: DataInputStream): Array[Long] = { - val len = readInt(in) - (0 until len).map(_ => readLong(in)).toArray - } - - private def readDoubleArr(in: DataInputStream): Array[Double] = { - val len = readInt(in) - (0 until len).map(_ => readDouble(in)).toArray - } - - private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { - val len = readInt(in) - (0 until len).map(_ => readDoubleArr(in)).toArray - } - - private def readBooleanArr(in: DataInputStream): Array[Boolean] = { - val len = readInt(in) - (0 until len).map(_ => readBoolean(in)).toArray - } - - private def readStringArr(in: DataInputStream): Array[String] = { - val len = readInt(in) - (0 until len).map(_ => readString(in)).toArray - } - - private def readRowArr(in: DataInputStream): java.util.List[Row] = { - val len = readInt(in) - (0 until len).map(_ => readRow(in)).toList.asJava - } - - private def readObjectArr(in: DataInputStream): Seq[Any] = { - val len = readInt(in) - (0 until len).map(_ => readObject(in)) - } - - private def readList(dis: DataInputStream): Array[_] = { - val arrType = readObjectType(dis) - 
arrType match { - case 'i' => readIntArr(dis) - case 'g' => readLongArr(dis) - case 'c' => readStringArr(dis) - case 'd' => readDoubleArr(dis) - case 'A' => readDoubleArrArr(dis) - case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) - case 'r' => readBytesArr(dis) - case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") - } - } - - private def readMap(in: DataInputStream): java.util.Map[Object, Object] = { - val len = readInt(in) - if (len > 0) { - val keysType = readObjectType(in) - val keysLen = readInt(in) - val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) - - val valuesLen = readInt(in) - val values = (0 until valuesLen).map(_ => { - val valueType = readObjectType(in) - readTypedObject(in, valueType) - }) - keys.zip(values).toMap.asJava - } else { - new java.util.HashMap[Object, Object]() - } - } - - // Using the same mapping as SparkR implementation for now - // Methods to write out data from Java to .NET. - // - // Type mapping from Java to .NET: - // - // void -> NULL - // Int -> integer - // String -> character - // Boolean -> logical - // Float -> double - // Double -> double - // Long -> long - // Array[Byte] -> raw - // Date -> Date - // Time -> POSIXct - // - // Array[T] -> list() - // Object -> jobj - - def writeType(dos: DataOutputStream, typeStr: String): Unit = { - typeStr match { - case "void" => dos.writeByte('n') - case "character" => dos.writeByte('c') - case "double" => dos.writeByte('d') - case "doublearray" => dos.writeByte('A') - case "long" => dos.writeByte('g') - case "integer" => dos.writeByte('i') - case "logical" => dos.writeByte('b') - case "date" => dos.writeByte('D') - case "time" => dos.writeByte('t') - case "raw" => dos.writeByte('r') - case "list" => dos.writeByte('l') - case "jobj" => dos.writeByte('j') - case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") - } - } - - def writeObject(dos: DataOutputStream, value: Object): Unit = { - if (value == null || value == Unit) { - writeType(dos, "void") - } else { - value.getClass.getName match { - case "java.lang.String" => - writeType(dos, "character") - writeString(dos, value.asInstanceOf[String]) - case "float" | "java.lang.Float" => - writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Float].toDouble) - case "double" | "java.lang.Double" => - writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Double]) - case "long" | "java.lang.Long" => - writeType(dos, "long") - writeLong(dos, value.asInstanceOf[Long]) - case "int" | "java.lang.Integer" => - writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Int]) - case "boolean" | "java.lang.Boolean" => - writeType(dos, "logical") - writeBoolean(dos, value.asInstanceOf[Boolean]) - case "java.sql.Date" => - writeType(dos, "date") - writeDate(dos, value.asInstanceOf[Date]) - case "java.sql.Time" => - writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Time]) - case "java.sql.Timestamp" => - writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Timestamp]) - case "[B" => - writeType(dos, "raw") - writeBytes(dos, value.asInstanceOf[Array[Byte]]) - // TODO: Types not handled right now include - // byte, char, short, float - - // Handle arrays - case "[Ljava.lang.String;" => - writeType(dos, "list") - writeStringArr(dos, value.asInstanceOf[Array[String]]) - case "[I" => - writeType(dos, "list") - writeIntArr(dos, value.asInstanceOf[Array[Int]]) - case "[J" => - writeType(dos, "list") - writeLongArr(dos, 
value.asInstanceOf[Array[Long]]) - case "[D" => - writeType(dos, "list") - writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) - case "[[D" => - writeType(dos, "list") - writeDoubleArrArr(dos, value.asInstanceOf[Array[Array[Double]]]) - case "[Z" => - writeType(dos, "list") - writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) - case "[[B" => - writeType(dos, "list") - writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) - case otherName => - // Handle array of objects - if (otherName.startsWith("[L")) { - val objArr = value.asInstanceOf[Array[Object]] - writeType(dos, "list") - writeType(dos, "jobj") - dos.writeInt(objArr.length) - objArr.foreach(o => writeJObj(dos, o)) - } else { - writeType(dos, "jobj") - writeJObj(dos, value) - } - } - } - } - - def writeInt(out: DataOutputStream, value: Int): Unit = { - out.writeInt(value) - } - - def writeLong(out: DataOutputStream, value: Long): Unit = { - out.writeLong(value) - } - - private def writeDouble(out: DataOutputStream, value: Double): Unit = { - out.writeDouble(value) - } - - private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { - out.writeBoolean(value) - } - - private def writeDate(out: DataOutputStream, value: Date): Unit = { - writeString(out, value.toString) - } - - private def writeTime(out: DataOutputStream, value: Time): Unit = { - out.writeDouble(value.getTime.toDouble / 1000.0) - } - - private def writeTime(out: DataOutputStream, value: Timestamp): Unit = { - out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) - } - - def writeString(out: DataOutputStream, value: String): Unit = { - val utf8 = value.getBytes(StandardCharsets.UTF_8) - val len = utf8.length - out.writeInt(len) - out.write(utf8, 0, len) - } - - private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { - out.writeInt(value.length) - out.write(value) - } - - def writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = tracker.put(value) - writeString(out, objId) - } - - private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { - writeType(out, "integer") - out.writeInt(value.length) - value.foreach(v => out.writeInt(v)) - } - - private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { - writeType(out, "long") - out.writeInt(value.length) - value.foreach(v => out.writeLong(v)) - } - - private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { - writeType(out, "double") - out.writeInt(value.length) - value.foreach(v => out.writeDouble(v)) - } - - private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { - writeType(out, "doublearray") - out.writeInt(value.length) - value.foreach(v => writeDoubleArr(out, v)) - } - - private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { - writeType(out, "logical") - out.writeInt(value.length) - value.foreach(v => writeBoolean(out, v)) - } - - private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { - writeType(out, "character") - out.writeInt(value.length) - value.foreach(v => writeString(out, v)) - } - - private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { - writeType(out, "raw") - out.writeInt(value.length) - value.foreach(v => writeBytes(out, v)) - } -} - -private object SerializationFormats { - val BYTE = "byte" - val STRING = "string" - val ROW = "row" -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala 
b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala deleted file mode 100644 index 50551a7d9..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import java.util.concurrent.{ExecutorService, Executors} - -import scala.collection.mutable - -/** - * Pool of thread executors. There should be a 1-1 correspondence between C# threads - * and Java threads. - */ -object ThreadPool { - - /** - * Map from (processId, threadId) to corresponding executor. - */ - private val executors: mutable.HashMap[(Int, Int), ExecutorService] = - new mutable.HashMap[(Int, Int), ExecutorService]() - - /** - * Run some code on a particular thread. - * @param processId Integer id of the process. - * @param threadId Integer id of the thread. - * @param task Function to run on the thread. - */ - def run(processId: Int, threadId: Int, task: () => Unit): Unit = { - val executor = getOrCreateExecutor(processId, threadId) - val future = executor.submit(new Runnable { - override def run(): Unit = task() - }) - - future.get() - } - - /** - * Try to delete a particular thread. - * @param processId Integer id of the process. - * @param threadId Integer id of the thread. - * @return True if successful, false if thread does not exist. - */ - def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized { - executors.remove((processId, threadId)) match { - case Some(executorService) => - executorService.shutdown() - true - case None => false - } - } - - /** - * Shutdown any running ExecutorServices. - */ - def shutdown(): Unit = synchronized { - executors.foreach(_._2.shutdown()) - executors.clear() - } - - /** - * Get the executor if it exists, otherwise create a new one. - * @param processId Integer id of the process. - * @param threadId Integer id of the thread. - * @return The new or existing executor with the given id. - */ - private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized { - executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor) - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala deleted file mode 100644 index b11a34900..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package org.apache.spark.deploy.dotnet - -import java.io.File -import java.net.URI -import java.nio.file.attribute.PosixFilePermissions -import java.nio.file.{FileSystems, Files, Paths} -import java.util.Locale -import java.util.concurrent.{Semaphore, TimeUnit} - -import org.apache.commons.io.FilenameUtils -import org.apache.hadoop.fs.Path -import org.apache.spark -import org.apache.spark.api.dotnet.DotnetBackend -import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK -import org.apache.spark.util.dotnet.{Utils => DotnetUtils} -import org.apache.spark.util.{RedirectThread, Utils} -import org.apache.spark.{SecurityManager, SparkConf, SparkUserAppException} - -import scala.collection.JavaConverters._ -import scala.io.StdIn -import scala.util.Try - -/** - * DotnetRunner class used to launch Spark .NET applications using spark-submit. - * It executes .NET application as a subprocess and then has it connect back to - * the JVM to access system properties etc. - */ -object DotnetRunner extends Logging { - private val DEBUG_PORT = 5567 - private val supportedSparkMajorMinorVersionPrefix = "2.3" - private val supportedSparkVersions = Set[String]("2.3.0", "2.3.1", "2.3.2", "2.3.3", "2.3.4") - - val SPARK_VERSION = DotnetUtils.normalizeSparkVersion(spark.SPARK_VERSION) - - def main(args: Array[String]): Unit = { - if (args.length == 0) { - throw new IllegalArgumentException("At least one argument is expected.") - } - - DotnetUtils.validateSparkVersions( - sys.props - .getOrElse( - DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.key, - DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.defaultValue.get.toString) - .toBoolean, - spark.SPARK_VERSION, - SPARK_VERSION, - supportedSparkMajorMinorVersionPrefix, - supportedSparkVersions) - - val settings = initializeSettings(args) - - // Determines if this needs to be run in debug mode. - // In debug mode this runner will not launch a .NET process. - val runInDebugMode = settings._1 - @volatile var dotnetBackendPortNumber = settings._2 - var dotnetExecutable = "" - var otherArgs: Array[String] = null - - if (!runInDebugMode) { - if (args(0).toLowerCase(Locale.ROOT).endsWith(".zip")) { - var zipFileName = args(0) - val zipFileUri = Try(new URI(zipFileName)).getOrElse(new File(zipFileName).toURI) - val workingDir = new File("").getAbsoluteFile - val driverDir = new File(workingDir, FilenameUtils.getBaseName(zipFileUri.getPath())) - - // Standalone cluster mode where .NET application is remotely located. - if (zipFileUri.getScheme() != "file") { - zipFileName = downloadDriverFile(zipFileName, workingDir.getAbsolutePath).getName - } - - logInfo(s"Unzipping .NET driver $zipFileName to $driverDir") - DotnetUtils.unzip(new File(zipFileName), driverDir) - - // Reuse windows-specific formatting in PythonRunner. - dotnetExecutable = PythonRunner.formatPath(resolveDotnetExecutable(driverDir, args(1))) - otherArgs = args.slice(2, args.length) - } else { - // Reuse windows-specific formatting in PythonRunner. 
- dotnetExecutable = PythonRunner.formatPath(args(0)) - otherArgs = args.slice(1, args.length) - } - } else { - otherArgs = args.slice(1, args.length) - } - - val processParameters = new java.util.ArrayList[String] - processParameters.add(dotnetExecutable) - otherArgs.foreach(arg => processParameters.add(arg)) - - logInfo(s"Starting DotnetBackend with $dotnetExecutable.") - - // Time to wait for DotnetBackend to initialize in seconds. - val backendTimeout = sys.env.getOrElse("DOTNETBACKEND_TIMEOUT", "120").toInt - - // Launch a DotnetBackend server for the .NET process to connect to; this will let it see our - // Java system properties etc. - val dotnetBackend = new DotnetBackend() - val initialized = new Semaphore(0) - val dotnetBackendThread = new Thread("DotnetBackend") { - override def run() { - // need to get back dotnetBackendPortNumber because if the value passed to init is 0 - // the port number is dynamically assigned in the backend - dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber) - logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber") - initialized.release() - dotnetBackend.run() - } - } - - dotnetBackendThread.start() - - if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { - if (!runInDebugMode) { - var returnCode = -1 - var process: Process = null - try { - val builder = new ProcessBuilder(processParameters) - val env = builder.environment() - env.put("DOTNETBACKEND_PORT", dotnetBackendPortNumber.toString) - - for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) { - env.put(key, value) - logInfo(s"Adding key=$key and value=$value to environment") - } - builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize - process = builder.start() - - // Redirect stdin of JVM process to stdin of .NET process. - new RedirectThread(System.in, process.getOutputStream, "redirect JVM input").start() - // Redirect stdout and stderr of .NET process. - new RedirectThread(process.getInputStream, System.out, "redirect .NET stdout").start() - new RedirectThread(process.getErrorStream, System.out, "redirect .NET stderr").start() - - process.waitFor() - } catch { - case t: Throwable => - logThrowable(t) - } finally { - returnCode = closeDotnetProcess(process) - closeBackend(dotnetBackend) - } - - if (returnCode != 0) { - throw new SparkUserAppException(returnCode) - } else { - logInfo(s".NET application exited successfully") - } - // TODO: The following is causing the following error: - // INFO ApplicationMaster: Final app status: FAILED, exitCode: 16, - // (reason: Shutdown hook called before final status was reported.) - // DotnetUtils.exit(returnCode) - } else { - // scalastyle:off println - println("***********************************************************************") - println("* .NET Backend running debug mode. Press enter to exit *") - println("***********************************************************************") - // scalastyle:on println - - StdIn.readLine() - closeBackend(dotnetBackend) - DotnetUtils.exit(0) - } - } else { - logError(s"DotnetBackend did not initialize in $backendTimeout seconds") - DotnetUtils.exit(-1) - } - } - - // When the executable is downloaded as part of zip file, check if the file exists - // after zip file is unzipped under the given dir. Once it is found, change the - // permission to executable (only for Unix systems, since the zip file may have been - // created under Windows. Finally, the absolute path for the executable is returned. 
- private def resolveDotnetExecutable(dir: File, dotnetExecutable: String): String = { - val path = Paths.get(dir.getAbsolutePath, dotnetExecutable) - val resolvedExecutable = if (Files.isRegularFile(path)) { - path.toAbsolutePath.toString - } else { - Files - .walk(FileSystems.getDefault.getPath(dir.getAbsolutePath)) - .iterator() - .asScala - .find(path => Files.isRegularFile(path) && path.getFileName.toString == dotnetExecutable) match { - case Some(path) => path.toAbsolutePath.toString - case None => - throw new IllegalArgumentException( - s"Failed to find $dotnetExecutable under ${dir.getAbsolutePath}") - } - } - - if (DotnetUtils.supportPosix) { - Files.setPosixFilePermissions( - Paths.get(resolvedExecutable), - PosixFilePermissions.fromString("rwxr-xr-x")) - } - - resolvedExecutable - } - - /** - * Download HDFS file into the supplied directory and return its local path. - * Will throw an exception if there are errors during downloading. - */ - private def downloadDriverFile(hdfsFilePath: String, driverDir: String): File = { - val sparkConf = new SparkConf() - val filePath = new Path(hdfsFilePath) - - val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) - val jarFileName = filePath.getName - val localFile = new File(driverDir, jarFileName) - - if (!localFile.exists()) { // May already exist if running multiple workers on one node - logInfo(s"Copying user file $filePath to $driverDir") - Utils.fetchFile( - hdfsFilePath, - new File(driverDir), - sparkConf, - new SecurityManager(sparkConf), - hadoopConf, - System.currentTimeMillis(), - useCache = false) - } - - if (!localFile.exists()) { - throw new Exception(s"Did not see expected $jarFileName in $driverDir") - } - - localFile - } - - private def closeBackend(dotnetBackend: DotnetBackend): Unit = { - logInfo("Closing DotnetBackend") - dotnetBackend.close() - } - - private def closeDotnetProcess(dotnetProcess: Process): Int = { - if (dotnetProcess == null) { - return -1 - } else if (!dotnetProcess.isAlive) { - return dotnetProcess.exitValue() - } - - // Try to (gracefully on Linux) kill the process and resort to force if interrupted - var returnCode = -1 - logInfo("Closing .NET process") - try { - dotnetProcess.destroy() - returnCode = dotnetProcess.waitFor() - } catch { - case _: InterruptedException => - logInfo( - "Thread interrupted while waiting for graceful close. Forcefully closing .NET process") - returnCode = dotnetProcess.destroyForcibly().waitFor() - case t: Throwable => - logThrowable(t) - } - - returnCode - } - - private def initializeSettings(args: Array[String]): (Boolean, Int) = { - val runInDebugMode = (args.length == 1 || args.length == 2) && args(0).equalsIgnoreCase( - "debug") - var portNumber = 0 - if (runInDebugMode) { - if (args.length == 1) { - portNumber = DEBUG_PORT - } else if (args.length == 2) { - portNumber = Integer.parseInt(args(1)) - } - } - - (runInDebugMode, portNumber) - } - - private def logThrowable(throwable: Throwable): Unit = - logError(s"${throwable.getMessage} \n ${throwable.getStackTrace.mkString("\n")}") -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala deleted file mode 100644 index ad54af17f..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. 
- * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.internal.config.dotnet - -import org.apache.spark.internal.config.ConfigBuilder - -private[spark] object Dotnet { - val DOTNET_NUM_BACKEND_THREADS = ConfigBuilder("spark.dotnet.numDotnetBackendThreads").intConf - .createWithDefault(10) - - val DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK = - ConfigBuilder("spark.dotnet.ignoreSparkPatchVersionCheck").booleanConf - .createWithDefault(false) -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala deleted file mode 100644 index b5e97289a..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.sql.api.dotnet - -import java.util.{List => JList, Map => JMap} - -import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonFunction} -import org.apache.spark.broadcast.Broadcast - -object SQLUtils { - - /** - * Exposes createPythonFunction to the .NET client to enable registering UDFs. - */ - def createPythonFunction( - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVersion: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: PythonAccumulatorV2): PythonFunction = { - - PythonFunction( - command, - envVars, - pythonIncludes, - pythonExec, - pythonVersion, - broadcastVars, - accumulator) - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/sql/test/TestUtils.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/sql/test/TestUtils.scala deleted file mode 100644 index 1cd45aa95..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/sql/test/TestUtils.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.sql.test - -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.streaming.MemoryStream - -object TestUtils { - - /** - * Helper method to create typed MemoryStreams intended for use in unit tests. - * @param sqlContext The SQLContext. - * @param streamType The type of memory stream to create. This string is the `Name` - * property of the dotnet type. - * @return A typed MemoryStream. 
- */ - def createMemoryStream(implicit sqlContext: SQLContext, streamType: String): MemoryStream[_] = { - import sqlContext.implicits._ - - streamType match { - case "Int32" => MemoryStream[Int] - case "String" => MemoryStream[String] - case _ => throw new Exception(s"$streamType not supported") - } - } -} diff --git a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/util/dotnet/Utils.scala b/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/util/dotnet/Utils.scala deleted file mode 100644 index 523d63900..000000000 --- a/src/scala/microsoft-spark-2-3/src/main/scala/org/apache/spark/util/dotnet/Utils.scala +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.util.dotnet - -import java.io._ -import java.nio.file.attribute.PosixFilePermission -import java.nio.file.attribute.PosixFilePermission._ -import java.nio.file.{FileSystems, Files} -import java.util.{Timer, TimerTask} - -import org.apache.commons.compress.archivers.zip.{ZipArchiveEntry, ZipArchiveOutputStream, ZipFile} -import org.apache.commons.io.{FileUtils, IOUtils} -import org.apache.spark.SparkConf -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK - -import scala.collection.JavaConverters._ -import scala.collection.Set - -/** - * Utility methods. - */ -object Utils extends Logging { - private val posixFilePermissions = Array( - OWNER_READ, - OWNER_WRITE, - OWNER_EXECUTE, - GROUP_READ, - GROUP_WRITE, - GROUP_EXECUTE, - OTHERS_READ, - OTHERS_WRITE, - OTHERS_EXECUTE) - - val supportPosix: Boolean = - FileSystems.getDefault.supportedFileAttributeViews().contains("posix") - - /** - * Compress all files under given directory into one zip file and drop it to the target directory - * - * @param sourceDir source directory to zip - * @param targetZipFile target zip file - */ - def zip(sourceDir: File, targetZipFile: File): Unit = { - var fos: FileOutputStream = null - var zos: ZipArchiveOutputStream = null - try { - fos = new FileOutputStream(targetZipFile) - zos = new ZipArchiveOutputStream(fos) - - val sourcePath = sourceDir.toPath - FileUtils.listFiles(sourceDir, null, true).asScala.foreach { file => - var in: FileInputStream = null - try { - val path = file.toPath - val entry = new ZipArchiveEntry(sourcePath.relativize(path).toString) - if (supportPosix) { - entry.setUnixMode( - permissionsToMode(Files.getPosixFilePermissions(path).asScala) - | (if (entry.getName.endsWith(".exe")) 0x1ED else 0x1A4)) - } else if (entry.getName.endsWith(".exe")) { - entry.setUnixMode(0x1ED) // 755 - } else { - entry.setUnixMode(0x1A4) // 644 - } - zos.putArchiveEntry(entry) - - in = new FileInputStream(file) - IOUtils.copy(in, zos) - zos.closeArchiveEntry() - } finally { - IOUtils.closeQuietly(in) - } - } - } finally { - IOUtils.closeQuietly(zos) - IOUtils.closeQuietly(fos) - } - } - - /** - * Unzip a file to the given directory - * - * @param file file to be unzipped - * @param targetDir target directory - */ - def unzip(file: File, targetDir: File): Unit = { - var zipFile: ZipFile = null - try { - targetDir.mkdirs() - zipFile = new ZipFile(file) - zipFile.getEntries.asScala.foreach { entry => - val targetFile = new File(targetDir, entry.getName) - - if (targetFile.exists()) { - logWarning( - s"Target file/directory 
$targetFile already exists. Skip it for now. " + - s"Make sure this is expected.") - } else { - if (entry.isDirectory) { - targetFile.mkdirs() - } else { - targetFile.getParentFile.mkdirs() - val input = zipFile.getInputStream(entry) - val output = new FileOutputStream(targetFile) - IOUtils.copy(input, output) - IOUtils.closeQuietly(input) - IOUtils.closeQuietly(output) - if (supportPosix) { - val permissions = modeToPermissions(entry.getUnixMode) - // When run in Unix system, permissions will be empty, thus skip - // setting the empty permissions (which will empty the previous permissions). - if (permissions.nonEmpty) { - Files.setPosixFilePermissions(targetFile.toPath, permissions.asJava) - } - } - } - } - } - } catch { - case e: Exception => logError("exception caught during decompression:" + e) - } finally { - ZipFile.closeQuietly(zipFile) - } - } - - /** - * Exits the JVM, trying to do it nicely, otherwise doing it nastily. - * - * @param status the exit status, zero for OK, non-zero for error - * @param maxDelayMillis the maximum delay in milliseconds - */ - def exit(status: Int, maxDelayMillis: Long) { - try { - logInfo(s"Utils.exit() with status: $status, maxDelayMillis: $maxDelayMillis") - - // setup a timer, so if nice exit fails, the nasty exit happens - val timer = new Timer() - timer.schedule(new TimerTask() { - - override def run() { - Runtime.getRuntime.halt(status) - } - }, maxDelayMillis) - // try to exit nicely - System.exit(status); - } catch { - // exit nastily if we have a problem - case _: Throwable => Runtime.getRuntime.halt(status) - } finally { - // should never get here - Runtime.getRuntime.halt(status) - } - } - - /** - * Exits the JVM, trying to do it nicely, wait 1 second - * - * @param status the exit status, zero for OK, non-zero for error - */ - def exit(status: Int): Unit = { - exit(status, 1000) - } - - /** - * Normalize the Spark version by taking the first three numbers. - * For example: - * x.y.z => x.y.z - * x.y.z.xxx.yyy => x.y.z - * x.y => x.y - * x.y.z => x.y.z - * - * @param version the Spark version to normalize - * @return Normalized Spark version. - */ - def normalizeSparkVersion(version: String): String = { - version - .split('.') - .take(3) - .zipWithIndex - .map({ - case (element, index) => { - index match { - case 2 => element.split("\\D+").lift(0).getOrElse("") - case _ => element - } - } - }) - .mkString(".") - } - - /** - * Validates the normalized spark version by verifying: - * - Spark version starts with sparkMajorMinorVersionPrefix. - * - If ignoreSparkPatchVersion is - * - true: valid - * - false: check if the spark version is in supportedSparkVersions. - * @param ignoreSparkPatchVersion Ignore spark patch version. - * @param sparkVersion The spark version. - * @param normalizedSparkVersion: The normalized spark version. - * @param supportedSparkMajorMinorVersionPrefix The spark major and minor version to validate against. - * @param supportedSparkVersions The set of supported spark versions. - */ - def validateSparkVersions( - ignoreSparkPatchVersion: Boolean, - sparkVersion: String, - normalizedSparkVersion: String, - supportedSparkMajorMinorVersionPrefix: String, - supportedSparkVersions: Set[String]): Unit = { - if (!normalizedSparkVersion.startsWith(s"$supportedSparkMajorMinorVersionPrefix.")) { - throw new IllegalArgumentException( - s"Unsupported spark version used: '$sparkVersion'. " + - s"Normalized spark version used: '$normalizedSparkVersion'. 
" + - s"Supported spark major.minor version: '$supportedSparkMajorMinorVersionPrefix'.") - } else if (ignoreSparkPatchVersion) { - logWarning( - s"Ignoring spark patch version. Spark version used: '$sparkVersion'. " + - s"Normalized spark version used: '$normalizedSparkVersion'. " + - s"Spark major.minor prefix used: '$supportedSparkMajorMinorVersionPrefix'.") - } else if (!supportedSparkVersions(normalizedSparkVersion)) { - val supportedVersions = supportedSparkVersions.toSeq.sorted.mkString(", ") - throw new IllegalArgumentException( - s"Unsupported spark version used: '$sparkVersion'. " + - s"Normalized spark version used: '$normalizedSparkVersion'. " + - s"Supported versions: '$supportedVersions'.") - } - } - - private[spark] def listZipFileEntries(file: File): Array[String] = { - var zipFile: ZipFile = null - try { - zipFile = new ZipFile(file) - zipFile.getEntries.asScala.map(_.getName).toArray - } finally { - ZipFile.closeQuietly(zipFile) - } - } - - private[this] def permissionsToMode(permissions: Set[PosixFilePermission]): Int = { - posixFilePermissions.foldLeft(0) { (mode, perm) => - (mode << 1) | (if (permissions.contains(perm)) 1 else 0) - } - } - - private[this] def modeToPermissions(mode: Int): Set[PosixFilePermission] = { - posixFilePermissions.zipWithIndex - .filter { case (_, i) => (mode & (0x100 >>> i)) != 0 } - .map(_._1) - .toSet - } -} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala deleted file mode 100644 index 6fe7e840f..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package org.apache.spark.api.dotnet - -import org.junit.Assert._ -import org.junit.{After, Before, Test} - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} - -@Test -class DotnetBackendHandlerTest { - private var backend: DotnetBackend = _ - private var tracker: JVMObjectTracker = _ - private var handler: DotnetBackendHandler = _ - - @Before - def before(): Unit = { - backend = new DotnetBackend - tracker = new JVMObjectTracker - handler = new DotnetBackendHandler(backend, tracker) - } - - @After - def after(): Unit = { - backend.close() - } - - @Test - def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { - val message = givenMessage(m => { - val serDe = new SerDe(null) - m.writeBoolean(true) // static method - serDe.writeInt(m, 1) // processId - serDe.writeInt(m, 1) // threadId - serDe.writeString(m, "DotnetHandler") // class name - serDe.writeString(m, "connectCallback") // command (method) name - m.writeInt(2) // number of arguments - m.writeByte('c') // 1st argument type (string) - serDe.writeString(m, "127.0.0.1") // 1st argument value (host) - m.writeByte('i') // 2nd argument type (integer) - m.writeInt(0) // 2nd argument value (port) - }) - - val payload = handler.handleBackendRequest(message) - val reply = new DataInputStream(new ByteArrayInputStream(payload)) - - assertEquals( - "status code must be successful.", 0, reply.readInt()) - assertEquals('n', reply.readByte()) - } - - private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = { - val buffer = new ByteArrayOutputStream() - func(new DataOutputStream(buffer)) - buffer.toByteArray - } -} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala deleted file mode 100644 index 1abf10e20..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import org.junit.Assert._ -import org.junit.function.ThrowingRunnable -import org.junit.{After, Before, Test} - -import java.net.InetAddress - -@Test -class DotnetBackendTest { - private var backend: DotnetBackend = _ - - @Before - def before(): Unit = { - backend = new DotnetBackend - } - - @After - def after(): Unit = { - backend.close() - } - - @Test - def shouldNotResetCallbackClient(): Unit = { - // Specifying port = 0 to select port dynamically. - backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) - - assertTrue(backend.callbackClient.isDefined) - assertThrows( - classOf[Exception], - new ThrowingRunnable { - override def run(): Unit = { - backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) - } - }) - } -} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala deleted file mode 100644 index 8c6e51608..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. 
- * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import java.io.DataInputStream - -private[dotnet] object Extensions { - implicit class DataInputStreamExt(stream: DataInputStream) { - def readNBytes(n: Int): Array[Byte] = { - val buf = new Array[Byte](n) - stream.readFully(buf) - buf - } - } -} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala deleted file mode 100644 index 43ae79005..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. - */ - -package org.apache.spark.api.dotnet - -import org.junit.Test - -@Test -class JVMObjectTrackerTest { - - @Test - def shouldReleaseAllReferences(): Unit = { - val tracker = new JVMObjectTracker - val firstId = tracker.put(new Object) - val secondId = tracker.put(new Object) - val thirdId = tracker.put(new Object) - - tracker.clear() - - assert(tracker.get(firstId).isEmpty) - assert(tracker.get(secondId).isEmpty) - assert(tracker.get(thirdId).isEmpty) - } - - @Test - def shouldResetCounter(): Unit = { - val tracker = new JVMObjectTracker - val firstId = tracker.put(new Object) - val secondId = tracker.put(new Object) - - tracker.clear() - - val thirdId = tracker.put(new Object) - - assert(firstId.equals("1")) - assert(secondId.equals("2")) - assert(thirdId.equals("1")) - } -} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala deleted file mode 100644 index 78ca905bb..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala +++ /dev/null @@ -1,386 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package org.apache.spark.api.dotnet - -import org.apache.spark.api.dotnet.Extensions._ -import org.apache.spark.sql.Row -import org.junit.Assert._ -import org.junit.function.ThrowingRunnable -import org.junit.{Before, Test} - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import java.sql.Date -import scala.collection.JavaConverters.{mapAsJavaMapConverter, seqAsJavaListConverter} - -@Test -class SerDeTest { - private var serDe: SerDe = _ - private var tracker: JVMObjectTracker = _ - - @Before - def before(): Unit = { - tracker = new JVMObjectTracker - serDe = new SerDe(tracker) - } - - @Test - def shouldReadNull(): Unit = { - val input = givenInput(in => { - in.writeByte('n') - }) - - assertEquals(null, serDe.readObject(input)) - } - - @Test - def shouldThrowForUnsupportedTypes(): Unit = { - val input = givenInput(in => { - in.writeByte('_') - }) - - assertThrows( - classOf[IllegalArgumentException], - new ThrowingRunnable { - override def run(): Unit = { - serDe.readObject(input) - } - }) - } - - @Test - def shouldReadInteger(): Unit = { - val input = givenInput(in => { - in.writeByte('i') - in.writeInt(42) - }) - - assertEquals(42, serDe.readObject(input)) - } - - @Test - def shouldReadLong(): Unit = { - val input = givenInput(in => { - in.writeByte('g') - in.writeLong(42) - }) - - assertEquals(42L, serDe.readObject(input)) - } - - @Test - def shouldReadDouble(): Unit = { - val input = givenInput(in => { - in.writeByte('d') - in.writeDouble(42.42) - }) - - assertEquals(42.42, serDe.readObject(input)) - } - - @Test - def shouldReadBoolean(): Unit = { - val input = givenInput(in => { - in.writeByte('b') - in.writeBoolean(true) - }) - - assertEquals(true, serDe.readObject(input)) - } - - @Test - def shouldReadString(): Unit = { - val payload = "Spark Dotnet" - val input = givenInput(in => { - in.writeByte('c') - in.writeInt(payload.getBytes("UTF-8").length) - in.write(payload.getBytes("UTF-8")) - }) - - assertEquals(payload, serDe.readObject(input)) - } - - @Test - def shouldReadMap(): Unit = { - val input = givenInput(in => { - in.writeByte('e') // map type descriptor - in.writeInt(3) // size - in.writeByte('i') // key type - in.writeInt(3) // number of keys - in.writeInt(11) // first key - in.writeInt(22) // second key - in.writeInt(33) // third key - in.writeInt(3) // number of values - in.writeByte('b') // first value type - in.writeBoolean(true) // first value - in.writeByte('d') // second value type - in.writeDouble(42.42) // second value - in.writeByte('n') // third type & value - }) - - assertEquals( - Map( - 11 -> true, - 22 -> 42.42, - 33 -> null).asJava, - serDe.readObject(input)) - } - - @Test - def shouldReadEmptyMap(): Unit = { - val input = givenInput(in => { - in.writeByte('e') // map type descriptor - in.writeInt(0) // size - }) - - assertEquals(Map().asJava, serDe.readObject(input)) - } - - @Test - def shouldReadBytesArray(): Unit = { - val input = givenInput(in => { - in.writeByte('r') // byte array type descriptor - in.writeInt(3) // length - in.write(Array[Byte](1, 2, 3)) // payload - }) - - assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]]) - } - - @Test - def shouldReadEmptyBytesArray(): Unit = { - val input = givenInput(in => { - in.writeByte('r') // byte array type descriptor - in.writeInt(0) // length - }) - - assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]]) - } - - @Test - def shouldReadEmptyList(): Unit = { - val input = 
givenInput(in => { - in.writeByte('l') // type descriptor - in.writeByte('i') // element type - in.writeInt(0) // length - }) - - assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]]) - } - - @Test - def shouldReadList(): Unit = { - val input = givenInput(in => { - in.writeByte('l') // type descriptor - in.writeByte('b') // element type - in.writeInt(3) // length - in.writeBoolean(true) - in.writeBoolean(false) - in.writeBoolean(true) - }) - - assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]]) - } - - @Test - def shouldThrowWhenReadingListWithUnsupportedType(): Unit = { - val input = givenInput(in => { - in.writeByte('l') // type descriptor - in.writeByte('_') // unsupported element type - }) - - assertThrows( - classOf[IllegalArgumentException], - new ThrowingRunnable { - override def run(): Unit = { - serDe.readObject(input) - } - }) - } - - @Test - def shouldReadDate(): Unit = { - val input = givenInput(in => { - val date = "2020-12-31" - in.writeByte('D') // type descriptor - in.writeInt(date.getBytes("UTF-8").length) // date string size - in.write(date.getBytes("UTF-8")) - }) - - assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input)) - } - - @Test - def shouldReadObject(): Unit = { - val trackingObject = new Object - tracker.put(trackingObject) - val input = givenInput(in => { - val objectIndex = "1" - in.writeByte('j') // type descriptor - in.writeInt(objectIndex.getBytes("UTF-8").length) // size - in.write(objectIndex.getBytes("UTF-8")) - }) - - assertSame(trackingObject, serDe.readObject(input)) - } - - @Test - def shouldThrowWhenReadingNonTrackingObject(): Unit = { - val input = givenInput(in => { - val objectIndex = "42" - in.writeByte('j') // type descriptor - in.writeInt(objectIndex.getBytes("UTF-8").length) // size - in.write(objectIndex.getBytes("UTF-8")) - }) - - assertThrows( - classOf[NoSuchElementException], - new ThrowingRunnable { - override def run(): Unit = { - serDe.readObject(input) - } - }) - } - - @Test - def shouldReadSparkRows(): Unit = { - val input = givenInput(in => { - in.writeByte('R') // type descriptor - in.writeInt(2) // number of rows - in.writeInt(1) // number of elements in 1st row - in.writeByte('i') // type of 1st element in 1st row - in.writeInt(11) - in.writeInt(3) // number of elements in 2st row - in.writeByte('b') // type of 1st element in 2nd row - in.writeBoolean(true) - in.writeByte('d') // type of 2nd element in 2nd row - in.writeDouble(42.24) - in.writeByte('g') // type of 3nd element in 2nd row - in.writeLong(99) - }) - - assertEquals( - Seq( - Row.fromSeq(Seq(11)), - Row.fromSeq(Seq(true, 42.24, 99))).asJava, - serDe.readObject(input)) - } - - @Test - def shouldReadArrayOfObjects(): Unit = { - val input = givenInput(in => { - in.writeByte('O') // type descriptor - in.writeInt(2) // number of elements - in.writeByte('i') // type of 1st element - in.writeInt(42) - in.writeByte('b') // type of 2nd element - in.writeBoolean(true) - }) - - assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]]) - } - - @Test - def shouldWriteNull(): Unit = { - val in = whenOutput(out => { - serDe.writeObject(out, null) - serDe.writeObject(out, Unit) - }) - - assertEquals(in.readByte(), 'n') - assertEquals(in.readByte(), 'n') - assertEndOfStream(in) - } - - @Test - def shouldWriteString(): Unit = { - val sparkDotnet = "Spark Dotnet" - val in = whenOutput(out => { - serDe.writeObject(out, sparkDotnet) - }) - - assertEquals(in.readByte(), 'c') // object type 
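-    // Framing exercised by the assertions below (as this test assumes): a one-byte 'c'
-    // type tag, a four-byte length written as a DataOutputStream big-endian int, then
-    // the string's raw UTF-8 bytes.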
- assertEquals(in.readInt(), sparkDotnet.length) // length - assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) - assertEndOfStream(in) - } - - @Test - def shouldWritePrimitiveTypes(): Unit = { - val in = whenOutput(out => { - serDe.writeObject(out, 42.24f.asInstanceOf[Object]) - serDe.writeObject(out, 42L.asInstanceOf[Object]) - serDe.writeObject(out, 42.asInstanceOf[Object]) - serDe.writeObject(out, true.asInstanceOf[Object]) - }) - - assertEquals(in.readByte(), 'd') - assertEquals(in.readDouble(), 42.24F, 0.000001) - assertEquals(in.readByte(), 'g') - assertEquals(in.readLong(), 42L) - assertEquals(in.readByte(), 'i') - assertEquals(in.readInt(), 42) - assertEquals(in.readByte(), 'b') - assertEquals(in.readBoolean(), true) - assertEndOfStream(in) - } - - @Test - def shouldWriteDate(): Unit = { - val date = "2020-12-31" - val in = whenOutput(out => { - serDe.writeObject(out, Date.valueOf(date)) - }) - - assertEquals(in.readByte(), 'D') // type - assertEquals(in.readInt(), 10) // size - assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content - } - - @Test - def shouldWriteCustomObjects(): Unit = { - val customObject = new Object - val in = whenOutput(out => { - serDe.writeObject(out, customObject) - }) - - assertEquals(in.readByte(), 'j') - assertEquals(in.readInt(), 1) - assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) - assertSame(tracker.get("1").get, customObject) - } - - @Test - def shouldWriteArrayOfCustomObjects(): Unit = { - val payload = Array(new Object, new Object) - val in = whenOutput(out => { - serDe.writeObject(out, payload) - }) - - assertEquals(in.readByte(), 'l') // array type - assertEquals(in.readByte(), 'j') // type of element in array - assertEquals(in.readInt(), 2) // array length - assertEquals(in.readInt(), 1) // size of 1st element's identifiers - assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element - assertEquals(in.readInt(), 1) // size of 2nd element's identifier - assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element - assertSame(tracker.get("1").get, payload(0)) - assertSame(tracker.get("2").get, payload(1)) - } - - private def givenInput(func: DataOutputStream => Unit): DataInputStream = { - val buffer = new ByteArrayOutputStream() - val out = new DataOutputStream(buffer) - func(out) - new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)) - } - - private def whenOutput = givenInput _ - - private def assertEndOfStream(in: DataInputStream): Unit = { - assertEquals(-1, in.read()) - } -} diff --git a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala b/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala deleted file mode 100644 index b21e96ca8..000000000 --- a/src/scala/microsoft-spark-2-3/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the .NET Foundation under one or more agreements. - * The .NET Foundation licenses this file to you under the MIT license. - * See the LICENSE file in the project root for more information. 
- */ - -package org.apache.spark.util.dotnet - -import org.apache.spark.SparkConf -import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK -import org.junit.Assert.{assertEquals, assertThrows} -import org.junit.Test -import org.junit.function.ThrowingRunnable - -@Test -class UtilsTest { - - @Test - def shouldIgnorePatchVersion(): Unit = { - val sparkVersion = "2.3.5" - val sparkMajorMinorVersionPrefix = "2.3" - val supportedSparkVersions = Set[String]("2.3.0", "2.3.1", "2.3.2", "2.3.3", "2.3.4") - - Utils.validateSparkVersions( - true, - sparkVersion, - Utils.normalizeSparkVersion(sparkVersion), - sparkMajorMinorVersionPrefix, - supportedSparkVersions) - } - - @Test - def shouldThrowForUnsupportedVersion(): Unit = { - val sparkVersion = "2.3.5" - val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion) - val sparkMajorMinorVersionPrefix = "2.3" - val supportedSparkVersions = Set[String]("2.3.0", "2.3.1", "2.3.2", "2.3.3", "2.3.4") - - val exception = assertThrows( - classOf[IllegalArgumentException], - new ThrowingRunnable { - override def run(): Unit = { - Utils.validateSparkVersions( - false, - sparkVersion, - normalizedSparkVersion, - sparkMajorMinorVersionPrefix, - supportedSparkVersions) - } - }) - - assertEquals( - s"Unsupported spark version used: '$sparkVersion'. " + - s"Normalized spark version used: '$normalizedSparkVersion'. " + - s"Supported versions: '${supportedSparkVersions.toSeq.sorted.mkString(", ")}'.", - exception.getMessage) - } - - @Test - def shouldThrowForUnsupportedMajorMinorVersion(): Unit = { - val sparkVersion = "2.4.4" - val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion) - val sparkMajorMinorVersionPrefix = "2.3" - val supportedSparkVersions = Set[String]("2.3.0", "2.3.1", "2.3.2", "2.3.3", "2.3.4") - - val exception = assertThrows( - classOf[IllegalArgumentException], - new ThrowingRunnable { - override def run(): Unit = { - Utils.validateSparkVersions( - false, - sparkVersion, - normalizedSparkVersion, - sparkMajorMinorVersionPrefix, - supportedSparkVersions) - } - }) - - assertEquals( - s"Unsupported spark version used: '$sparkVersion'. " + - s"Normalized spark version used: '$normalizedSparkVersion'. 
" + - s"Supported spark major.minor version: '$sparkMajorMinorVersionPrefix'.", - exception.getMessage) - } -} diff --git a/src/scala/pom.xml b/src/scala/pom.xml index d564fe621..fbef521a3 100644 --- a/src/scala/pom.xml +++ b/src/scala/pom.xml @@ -11,7 +11,6 @@ - microsoft-spark-2-3 microsoft-spark-2-4 microsoft-spark-3-0 microsoft-spark-3-1 From c4495fc5f0a9f5f5e9d26823db55894503c63b62 Mon Sep 17 00:00:00 2001 From: Steve Suh Date: Thu, 8 Apr 2021 12:26:57 -0700 Subject: [PATCH 2/8] remove 2.3 logic --- .../AssemblyKernelExtension.cs | 6 ++-- .../Microsoft.Spark.E2ETest/SparkSettings.cs | 4 +-- .../PayloadWriter.cs | 14 --------- .../TestData.cs | 4 --- .../Processor/BroadcastVariableProcessor.cs | 2 +- .../Processor/CommandProcessor.cs | 19 ------------ .../Processor/TaskContextProcessor.cs | 9 ------ .../Utils/SettingUtils.cs | 4 +-- src/csharp/Microsoft.Spark/Broadcast.cs | 30 +++---------------- src/csharp/Microsoft.Spark/Sql/DataFrame.cs | 12 ++------ src/csharp/Microsoft.Spark/Versions.cs | 4 --- 11 files changed, 15 insertions(+), 93 deletions(-) diff --git a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/AssemblyKernelExtension.cs b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/AssemblyKernelExtension.cs index a99e6ee0b..da2521283 100644 --- a/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/AssemblyKernelExtension.cs +++ b/src/csharp/Extensions/Microsoft.Spark.Extensions.DotNet.Interactive/AssemblyKernelExtension.cs @@ -145,10 +145,10 @@ private bool IsPathValid(string path) } Version version = SparkEnvironment.SparkVersion; - return (version.Major, version.Minor, version.Build) switch + return version.Major switch { - (2, _, _) => false, - (3, 0, _) => true, + 2 => false, + 3 => true, _ => throw new NotSupportedException($"Spark {version} not supported.") }; } diff --git a/src/csharp/Microsoft.Spark.E2ETest/SparkSettings.cs b/src/csharp/Microsoft.Spark.E2ETest/SparkSettings.cs index a568586d3..cf8e0bb43 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/SparkSettings.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/SparkSettings.cs @@ -32,11 +32,11 @@ private static void InitSparkHome() private static void InitVersion() { // First line of the RELEASE file under SPARK_HOME will be something similar to: - // Spark 2.3.2 built for Hadoop 2.7.3 + // Spark 2.4.0 built for Hadoop 2.7.3 string firstLine = File.ReadLines($"{SparkHome}{Path.DirectorySeparatorChar}RELEASE").First(); - // Grab "2.3.2" from "Spark 2.3.2 built for Hadoop 2.7.3" + // Grab "2.4.0" from "Spark 2.4.0 built for Hadoop 2.7.3" string versionStr = firstLine.Split(' ')[1]; // Strip anything below version number. 
diff --git a/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs b/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs index 0e5745f89..d556b6e89 100644 --- a/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs +++ b/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs @@ -343,20 +343,6 @@ internal PayloadWriter Create(Version version = null) switch (version.ToString()) { - case Versions.V2_3_0: - case Versions.V2_3_1: - return new PayloadWriter( - version, - new TaskContextWriterV2_3_X(), - new BroadcastVariableWriterV2_3_0(), - new CommandWriterV2_3_X()); - case Versions.V2_3_2: - case Versions.V2_3_3: - return new PayloadWriter( - version, - new TaskContextWriterV2_3_X(), - new BroadcastVariableWriterV2_3_2(), - new CommandWriterV2_3_X()); case Versions.V2_4_0: return new PayloadWriter( version, diff --git a/src/csharp/Microsoft.Spark.Worker.UnitTest/TestData.cs b/src/csharp/Microsoft.Spark.Worker.UnitTest/TestData.cs index 798d82d56..99865530e 100644 --- a/src/csharp/Microsoft.Spark.Worker.UnitTest/TestData.cs +++ b/src/csharp/Microsoft.Spark.Worker.UnitTest/TestData.cs @@ -16,10 +16,6 @@ internal static class TestData public static IEnumerable VersionData() => new List { - new object[] { Versions.V2_3_0 }, - new object[] { Versions.V2_3_1 }, - new object[] { Versions.V2_3_2 }, - new object[] { Versions.V2_3_3 }, new object[] { Versions.V2_4_0 }, new object[] { Versions.V3_0_0 }, }; diff --git a/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs b/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs index e3bc16df6..40a9ffae3 100644 --- a/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs +++ b/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs @@ -30,7 +30,7 @@ internal BroadcastVariables Process(Stream stream) var broadcastVars = new BroadcastVariables(); ISocketWrapper socket = null; - if (_version >= new Version(Versions.V2_3_2)) + if (_version >= new Version(Versions.V2_4_0)) { broadcastVars.DecryptionServerNeeded = SerDe.ReadBool(stream); } diff --git a/src/csharp/Microsoft.Spark.Worker/Processor/CommandProcessor.cs b/src/csharp/Microsoft.Spark.Worker/Processor/CommandProcessor.cs index 1619ea6aa..ebb25c650 100644 --- a/src/csharp/Microsoft.Spark.Worker/Processor/CommandProcessor.cs +++ b/src/csharp/Microsoft.Spark.Worker/Processor/CommandProcessor.cs @@ -97,7 +97,6 @@ private static SqlCommand[] ReadSqlCommands( return (version.Major, version.Minor) switch { - (2, 3) => SqlCommandProcessorV2_3_X.Process(evalType, stream), (2, 4) => SqlCommandProcessorV2_4_X.Process(evalType, stream), (3, _) => SqlCommandProcessorV2_4_X.Process(evalType, stream), _ => throw new NotSupportedException($"Spark {version} not supported.") @@ -224,24 +223,6 @@ private static SqlCommand[] ReadSqlCommands( return commands; } - private static class SqlCommandProcessorV2_3_X - { - internal static SqlCommand[] Process(PythonEvalType evalType, Stream stream) - { - SqlCommand[] sqlCommands = ReadSqlCommands(evalType, stream); - - if ((evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF) || - (evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)) - { - // Reads the timezone information. This is not going to be used until - // timestamp column is supported in Arrow. 
- SerDe.ReadString(stream); - } - - return sqlCommands; - } - } - private static class SqlCommandProcessorV2_4_X { internal static SqlCommand[] Process(PythonEvalType evalType, Stream stream) diff --git a/src/csharp/Microsoft.Spark.Worker/Processor/TaskContextProcessor.cs b/src/csharp/Microsoft.Spark.Worker/Processor/TaskContextProcessor.cs index 45da43f2c..0db9768d2 100644 --- a/src/csharp/Microsoft.Spark.Worker/Processor/TaskContextProcessor.cs +++ b/src/csharp/Microsoft.Spark.Worker/Processor/TaskContextProcessor.cs @@ -21,7 +21,6 @@ internal TaskContext Process(Stream stream) { return (_version.Major, _version.Minor) switch { - (2, 3) => TaskContextProcessorV2_3_X.Process(stream), (2, 4) => TaskContextProcessorV2_4_X.Process(stream), (3, _) => TaskContextProcessorV3_0_X.Process(stream), _ => throw new NotSupportedException($"Spark {_version} not supported.") @@ -74,14 +73,6 @@ private static void ReadTaskContextResources(Stream stream) } } - private static class TaskContextProcessorV2_3_X - { - internal static TaskContext Process(Stream stream) - { - return ReadTaskContext(stream); - } - } - private static class TaskContextProcessorV2_4_X { internal static TaskContext Process(Stream stream) diff --git a/src/csharp/Microsoft.Spark.Worker/Utils/SettingUtils.cs b/src/csharp/Microsoft.Spark.Worker/Utils/SettingUtils.cs index c48f1618a..a1086be44 100644 --- a/src/csharp/Microsoft.Spark.Worker/Utils/SettingUtils.cs +++ b/src/csharp/Microsoft.Spark.Worker/Utils/SettingUtils.cs @@ -14,14 +14,14 @@ internal static class SettingUtils { internal static string GetWorkerFactorySecret(Version version) { - return (version >= new Version(Versions.V2_3_1)) ? + return (version >= new Version(Versions.V2_4_0)) ? GetEnvironmentVariable("PYTHON_WORKER_FACTORY_SECRET") : null; } internal static int GetWorkerFactoryPort(Version version) { - string portStr = (version >= new Version(Versions.V2_3_1)) ? + string portStr = (version >= new Version(Versions.V2_4_0)) ? GetEnvironmentVariable("PYTHON_WORKER_FACTORY_PORT") : Console.ReadLine(); diff --git a/src/csharp/Microsoft.Spark/Broadcast.cs b/src/csharp/Microsoft.Spark/Broadcast.cs index f6e2697b0..d6529eb4a 100644 --- a/src/csharp/Microsoft.Spark/Broadcast.cs +++ b/src/csharp/Microsoft.Spark/Broadcast.cs @@ -125,43 +125,21 @@ private JvmObjectReference CreateBroadcast(SparkContext sc, T value) Version version = SparkEnvironment.SparkVersion; return (version.Major, version.Minor) switch { - (2, 3) when version.Build == 0 || version.Build == 1 => - CreateBroadcast_V2_3_1_AndBelow(javaSparkContext, value), - (2, 3) => CreateBroadcast_V2_3_2_AndAbove(javaSparkContext, sc, value), - (2, 4) => CreateBroadcast_V2_3_2_AndAbove(javaSparkContext, sc, value), - (3, _) => CreateBroadcast_V2_3_2_AndAbove(javaSparkContext, sc, value), + (2, 4) => CreateBroadcast_V2_4_X(javaSparkContext, sc, value), + (3, _) => CreateBroadcast_V2_4_X(javaSparkContext, sc, value), _ => throw new NotSupportedException($"Spark {version} not supported.") }; } - /// - /// Calls the necessary functions to create org.apache.spark.broadcast.Broadcast object - /// for Spark versions 2.3.0 and 2.3.1 and returns the JVMObjectReference object. 
- /// - /// Java Spark context object - /// Broadcast value of type object - /// Returns broadcast variable of type - private JvmObjectReference CreateBroadcast_V2_3_1_AndBelow( - JvmObjectReference javaSparkContext, - object value) - { - WriteToFile(value); - return (JvmObjectReference)javaSparkContext.Jvm.CallStaticJavaMethod( - "org.apache.spark.api.python.PythonRDD", - "readBroadcastFromFile", - javaSparkContext, - _path); - } - /// /// Calls the necessary Spark functions to create org.apache.spark.broadcast.Broadcast - /// object for Spark versions 2.3.2 and above, and returns the JVMObjectReference object. + /// object for Spark versions 2.4.0 and above, and returns the JVMObjectReference object. /// /// Java Spark context object /// SparkContext object /// Broadcast value of type object /// Returns broadcast variable of type - private JvmObjectReference CreateBroadcast_V2_3_2_AndAbove( + private JvmObjectReference CreateBroadcast_V2_4_X( JvmObjectReference javaSparkContext, SparkContext sc, object value) diff --git a/src/csharp/Microsoft.Spark/Sql/DataFrame.cs b/src/csharp/Microsoft.Spark/Sql/DataFrame.cs index 8e52c8e42..78807e3fd 100644 --- a/src/csharp/Microsoft.Spark/Sql/DataFrame.cs +++ b/src/csharp/Microsoft.Spark/Sql/DataFrame.cs @@ -1102,16 +1102,10 @@ private IEnumerable GetRows(string funcName, params object[] args) { object result = _jvmObject.Invoke(funcName, args); Version version = SparkEnvironment.SparkVersion; - return (version.Major, version.Minor, version.Build) switch + return (version.Major, version.Minor) switch { - // In spark 2.3.0, PythonFunction.serveIterator() returns a port number. - (2, 3, 0) => ((int)result, string.Empty, null), - // From spark >= 2.3.1, PythonFunction.serveIterator() returns a pair - // where the first is a port number and the second is the secret - // string to use for the authentication. 
- (2, 3, _) => ParseConnectionInfo(result, false), - (2, 4, _) => ParseConnectionInfo(result, false), - (3, _, _) => ParseConnectionInfo(result, false), + (2, 4) => ParseConnectionInfo(result, false), + (3, _) => ParseConnectionInfo(result, false), _ => throw new NotSupportedException($"Spark {version} not supported.") }; } diff --git a/src/csharp/Microsoft.Spark/Versions.cs b/src/csharp/Microsoft.Spark/Versions.cs index 62d9adcc6..a909233dd 100644 --- a/src/csharp/Microsoft.Spark/Versions.cs +++ b/src/csharp/Microsoft.Spark/Versions.cs @@ -6,10 +6,6 @@ namespace Microsoft.Spark { internal static class Versions { - internal const string V2_3_0 = "2.3.0"; - internal const string V2_3_1 = "2.3.1"; - internal const string V2_3_2 = "2.3.2"; - internal const string V2_3_3 = "2.3.3"; internal const string V2_4_0 = "2.4.0"; internal const string V2_4_2 = "2.4.2"; internal const string V3_0_0 = "3.0.0"; From 3f618575649519670c6bb56b466435152f2b2b44 Mon Sep 17 00:00:00 2001 From: Steve Suh Date: Thu, 8 Apr 2021 12:31:40 -0700 Subject: [PATCH 3/8] update pipeline --- azure-pipelines.yml | 52 --------------------------------------------- 1 file changed, 52 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 9ab3c7680..72741d518 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -16,11 +16,6 @@ variables: backwardCompatibleRelease: '1.0.0' forwardCompatibleRelease: '1.0.0' - backwardCompatibleTestOptions_Windows_2_3: "" - forwardCompatibleTestOptions_Windows_2_3: "" - backwardCompatibleTestOptions_Linux_2_3: "" - forwardCompatibleTestOptions_Linux_2_3: "" - backwardCompatibleTestOptions_Windows_2_4: "" forwardCompatibleTestOptions_Windows_2_4: "" backwardCompatibleTestOptions_Linux_2_4: "" @@ -206,53 +201,6 @@ stages: backwardCompatibleRelease: $(backwardCompatibleRelease) forwardCompatibleRelease: $(forwardCompatibleRelease) tests: - - version: '2.3.0' - jobOptions: - # 'Hosted Ubuntu 1604' test is disabled due to https://github.com/dotnet/spark/issues/753 - - pool: 'Hosted VS2017' - testOptions: '' - backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Windows_2_3) - forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Windows_2_3) - - version: '2.3.1' - jobOptions: - - pool: 'Hosted VS2017' - testOptions: "" - backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Windows_2_3) - forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Windows_2_3) - - pool: 'Hosted Ubuntu 1604' - testOptions: "" - backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Linux_2_3) - forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Linux_2_3) - - version: '2.3.2' - jobOptions: - - pool: 'Hosted VS2017' - testOptions: "" - backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Windows_2_3) - forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Windows_2_3) - - pool: 'Hosted Ubuntu 1604' - testOptions: "" - backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Linux_2_3) - forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Linux_2_3) - - version: '2.3.3' - jobOptions: - - pool: 'Hosted VS2017' - testOptions: "" - backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Windows_2_3) - forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Windows_2_3) - - pool: 'Hosted Ubuntu 1604' - testOptions: "" - backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Linux_2_3) - forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Linux_2_3) - - version: '2.3.4' - jobOptions: - - pool: 
'Hosted VS2017' - testOptions: "" - backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Windows_2_3) - forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Windows_2_3) - - pool: 'Hosted Ubuntu 1604' - testOptions: "" - backwardCompatibleTestOptions: $(backwardCompatibleTestOptions_Linux_2_3) - forwardCompatibleTestOptions: $(forwardCompatibleTestOptions_Linux_2_3) - version: '2.4.0' jobOptions: - pool: 'Hosted VS2017' From c309cfcc47dd8a088507ff58f3798a290fcc0361 Mon Sep 17 00:00:00 2001 From: Steve Suh Date: Thu, 8 Apr 2021 12:59:03 -0700 Subject: [PATCH 4/8] update/merge tests --- .../ML/Feature/StopWordsRemoverTests.cs | 14 +- .../IpcTests/SparkContextTests.cs | 4 +- .../IpcTests/Sql/ColumnTests.cs | 30 ++-- .../IpcTests/Sql/DataFrameReaderTests.cs | 6 +- .../IpcTests/Sql/DataFrameTests.cs | 34 ++--- .../IpcTests/Sql/DataFrameWriterTests.cs | 4 +- .../Sql/Expressions/WindowSpecTests.cs | 4 +- .../IpcTests/Sql/Expressions/WindowTests.cs | 4 +- .../IpcTests/Sql/FunctionsTests.cs | 141 ++++++++---------- .../IpcTests/Sql/RuntimeConfigTests.cs | 14 +- .../IpcTests/Sql/SparkSessionTests.cs | 12 +- .../Sql/Streaming/DataStreamReaderTests.cs | 4 +- .../Sql/Streaming/DataStreamWriterTests.cs | 4 +- .../Streaming/StreamingQueryManagerTests.cs | 4 +- .../Sql/Streaming/StreamingQueryTests.cs | 5 +- 15 files changed, 113 insertions(+), 171 deletions(-) diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StopWordsRemoverTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StopWordsRemoverTests.cs index 4bf614a44..832304e43 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StopWordsRemoverTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/StopWordsRemoverTests.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System.IO; -using Microsoft.Spark.E2ETest.Utils; using Microsoft.Spark.ML.Feature; using Microsoft.Spark.Sql; using Microsoft.Spark.Sql.Types; @@ -23,10 +22,10 @@ public StopWordsRemoverTests(SparkFixture fixture) : base(fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { string expectedUid = "theUidWithOutLocale"; string expectedInputCol = "input_col"; @@ -62,16 +61,9 @@ public void TestSignaturesV2_3_X() Assert.IsType(stopWordsRemover.Transform(input)); TestFeatureBase(stopWordsRemover, "inputCol", "input_col"); - } - /// - /// Test signatures for APIs introduced in Spark 2.4.*. - /// - [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)] - public void TestSignaturesV2_4_X() - { string expectedLocale = "en_GB"; - StopWordsRemover stopWordsRemover = new StopWordsRemover().SetLocale(expectedLocale); + stopWordsRemover.SetLocale(expectedLocale); Assert.Equal(expectedLocale, stopWordsRemover.GetLocale()); } } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs index 40e5a774e..e0b9fb6a0 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/SparkContextTests.cs @@ -17,13 +17,13 @@ namespace Microsoft.Spark.E2ETest.IpcTests public class SparkContextTests { /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// /// /// For the RDD related tests, refer to . 
/// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { SparkContext sc = SparkContext.GetOrCreate(new SparkConf()); diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/ColumnTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/ColumnTests.cs index 59fa6df99..6ffa2b3c7 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/ColumnTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/ColumnTests.cs @@ -14,10 +14,10 @@ namespace Microsoft.Spark.E2ETest.IpcTests public class ColumnTests { /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { Column col1 = Column("col1"); Column col2 = Column("col2"); @@ -28,22 +28,22 @@ public void TestSignaturesV2_3_X() Assert.IsType(col1 == col2); Assert.IsType(col1.EqualTo(col2)); - + Assert.IsType(col1 != col2); Assert.IsType(col1.NotEqual(col2)); - - Assert.IsType(col1 > col2); - Assert.IsType(col1 > "hello"); - Assert.IsType(col1.Gt(col2)); + + Assert.IsType(col1 > col2); + Assert.IsType(col1 > "hello"); + Assert.IsType(col1.Gt(col2)); Assert.IsType(col1.Gt("hello")); - - Assert.IsType(col1 < col2); - Assert.IsType(col1 < "hello"); - Assert.IsType(col1.Lt(col2)); + + Assert.IsType(col1 < col2); + Assert.IsType(col1 < "hello"); + Assert.IsType(col1.Lt(col2)); Assert.IsType(col1.Lt("hello")); Assert.IsType(col1 <= col2); - Assert.IsType(col1 <= "hello"); + Assert.IsType(col1 <= "hello"); Assert.IsType(col1.Leq(col2)); Assert.IsType(col1.Leq("hello")); @@ -59,7 +59,7 @@ public void TestSignaturesV2_3_X() Assert.IsType(When(col1 == col2, 0).Otherwise(col2)); Assert.IsType(When(col1 == col2, 0).Otherwise("hello")); - + Assert.IsType(col1.Between(col1, col2)); Assert.IsType(col1.Between(1, 3)); @@ -69,7 +69,7 @@ public void TestSignaturesV2_3_X() Assert.IsType(col1 | col2); Assert.IsType(col1.Or(col2)); - + Assert.IsType(col1 & col2); Assert.IsType(col1.And(col2)); @@ -139,7 +139,7 @@ public void TestSignaturesV2_3_X() Assert.IsType(col1.Over(PartitionBy(col1))); Assert.IsType(col1.Over()); - + Assert.Equal("col1", col1.ToString()); Assert.Equal("col2", col2.ToString()); } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameReaderTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameReaderTests.cs index 8b3ccb648..feb9b33ff 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameReaderTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameReaderTests.cs @@ -20,15 +20,15 @@ public DataFrameReaderTests(SparkFixture fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { DataFrameReader dfr = _spark.Read(); Assert.IsType(dfr.Format("json")); - + Assert.IsType( dfr.Schema( new StructType(new[] diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs index 58403b485..b95d71add 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs @@ -437,10 +437,10 @@ private static FxDataFrame CountCharacters(FxDataFrame dataFrame) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. 
/// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { Assert.IsType(_df["name"]); Assert.IsType(_df["age"]); @@ -569,6 +569,16 @@ public void TestSignaturesV2_3_X() Assert.IsType(df.Sum("age")); Assert.IsType(df.Sum("age", "tempAge")); + + var values = new List { 19, "twenty" }; + + Assert.IsType(df.Pivot("age")); + + Assert.IsType(df.Pivot(Col("age"))); + + Assert.IsType(df.Pivot("age", values)); + + Assert.IsType(df.Pivot(Col("age"), values)); } Assert.IsType(_df.Rollup("age")); @@ -669,32 +679,12 @@ public void TestSignaturesV2_3_X() _df.CreateOrReplaceGlobalTempView("global_view"); Assert.IsType(_df.InputFiles().ToArray()); - } - /// - /// Test signatures for APIs introduced in Spark 2.4.*. - /// - [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)] - public void TestSignaturesV2_4_X() - { _df.IsEmpty(); _df.IntersectAll(_df); _df.ExceptAll(_df); - - { - RelationalGroupedDataset df = _df.GroupBy("name"); - var values = new List { 19, "twenty" }; - - Assert.IsType(df.Pivot("age")); - - Assert.IsType(df.Pivot(Col("age"))); - - Assert.IsType(df.Pivot("age", values)); - - Assert.IsType(df.Pivot(Col("age"), values)); - } } /// diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameWriterTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameWriterTests.cs index c693fdabc..42e4908ea 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameWriterTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameWriterTests.cs @@ -20,10 +20,10 @@ public DataFrameWriterTests(SparkFixture fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { { DataFrameWriter dfw = _spark diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Expressions/WindowSpecTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Expressions/WindowSpecTests.cs index d37ea9570..ea8677ffe 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Expressions/WindowSpecTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Expressions/WindowSpecTests.cs @@ -15,10 +15,10 @@ namespace Microsoft.Spark.E2ETest.IpcTests public class WindowSpecTests { /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { Column col1 = Column("age"); Column col2 = Column("name"); diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Expressions/WindowTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Expressions/WindowTests.cs index 97c9e1508..06359029d 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Expressions/WindowTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Expressions/WindowTests.cs @@ -15,10 +15,10 @@ namespace Microsoft.Spark.E2ETest.IpcTests public class WindowTests { /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. 
/// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { Column col1 = Column("age"); Column col2 = Column("name"); diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/FunctionsTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/FunctionsTests.cs index d2b79609d..831b3d331 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/FunctionsTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/FunctionsTests.cs @@ -23,12 +23,12 @@ public FunctionsTests(SparkFixture fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// The purpose of this test is to ensure that JVM calls can be successfully made. /// Note that this is not testing functionality of each function. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { ////////////////////////////// // Basic Functions @@ -206,6 +206,8 @@ public void TestSignaturesV2_3_X() Assert.IsType(Map(col)); Assert.IsType(Map(col, col)); + Assert.IsType(MapFromArrays(col, col)); + DataFrame df = _spark .Read() .Json($"{TestEnvironment.ResourceDirectory}people.json"); @@ -514,6 +516,7 @@ public void TestSignaturesV2_3_X() Assert.IsType(Minute(col)); Assert.IsType(MonthsBetween(col, col)); + Assert.IsType(MonthsBetween(col, col, false)); Assert.IsType(NextDay(col, "Mon")); @@ -542,8 +545,10 @@ public void TestSignaturesV2_3_X() { // The following APIs are deprecated in Spark 3.0. Assert.IsType(FromUtcTimestamp(col, "GMT+1")); + Assert.IsType(FromUtcTimestamp(col, col)); Assert.IsType(ToUtcTimestamp(col, "GMT+1")); + Assert.IsType(ToUtcTimestamp(col, col)); } Assert.IsType(Window(col, "1 minute", "10 seconds", "5 seconds")); @@ -556,10 +561,33 @@ public void TestSignaturesV2_3_X() Assert.IsType(ArrayContains(col, 12345)); Assert.IsType(ArrayContains(col, "str")); + Assert.IsType(ArraysOverlap(col, col)); + + Assert.IsType(Slice(col, 0, 4)); + + Assert.IsType(ArrayJoin(col, ":", "replacement")); + Assert.IsType(ArrayJoin(col, ":")); + Assert.IsType(Concat()); Assert.IsType(Concat(col)); Assert.IsType(Concat(col, col)); + Assert.IsType(ArrayPosition(col, 1)); + + Assert.IsType(ElementAt(col, 1)); + + Assert.IsType(ArraySort(col)); + + Assert.IsType(ArrayRemove(col, "elementToRemove")); + + Assert.IsType(ArrayDistinct(col)); + + Assert.IsType(ArrayIntersect(col, col)); + + Assert.IsType(ArrayUnion(col, col)); + + Assert.IsType(ArrayExcept(col, col)); + Assert.IsType(Explode(col)); Assert.IsType(ExplodeOuter(col)); @@ -574,9 +602,15 @@ public void TestSignaturesV2_3_X() Assert.IsType(JsonTuple(col, "a", "b")); var options = new Dictionary() { { "hello", "world" } }; + Column schema = SchemaOfJson("[{\"col\":0}]"); Assert.IsType(FromJson(col, "a Int")); Assert.IsType(FromJson(col, "a Int", options)); + Assert.IsType(FromJson(col, schema)); + Assert.IsType(FromJson(col, schema, options)); + + Assert.IsType(SchemaOfJson("{}")); + Assert.IsType(SchemaOfJson(col)); Assert.IsType(ToJson(col)); Assert.IsType(ToJson(col, options)); @@ -587,12 +621,36 @@ public void TestSignaturesV2_3_X() Assert.IsType(SortArray(col, true)); Assert.IsType(SortArray(col, false)); + Assert.IsType(ArrayMin(col)); + + Assert.IsType(ArrayMax(col)); + + Assert.IsType(Shuffle(col)); + Assert.IsType(Reverse(col)); + Assert.IsType(Flatten(col)); + + Assert.IsType(Sequence(col, col, col)); + Assert.IsType(Sequence(col, col)); + + Assert.IsType(ArrayRepeat(col, col)); + Assert.IsType(ArrayRepeat(col, 5)); + Assert.IsType(MapKeys(col)); 
Assert.IsType(MapValues(col)); + Assert.IsType(MapFromEntries(col)); + + Assert.IsType(ArraysZip()); + Assert.IsType(ArraysZip(col)); + Assert.IsType(ArraysZip(col, col)); + + Assert.IsType(MapConcat()); + Assert.IsType(MapConcat(col)); + Assert.IsType(MapConcat(col, col)); + ////////////////////////////// // Udf Functions ////////////////////////////// @@ -666,85 +724,6 @@ private void TestUdf() (arg) => new Dictionary { { arg, new[] { arg } } }); } - /// - /// Test signatures for APIs introduced in Spark 2.4.*. - /// - [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)] - public void TestSignaturesV2_4_X() - { - Column col = Column("col"); - - col = MapFromArrays(col, col); - - col = MonthsBetween(col, col, false); - - if (SparkSettings.Version < new Version(Versions.V3_0_0)) - { - // The following APIs are deprecated in Spark 3.0. - col = FromUtcTimestamp(col, col); - - col = ToUtcTimestamp(col, col); - } - - col = ArraysOverlap(col, col); - - col = Slice(col, 0, 4); - - col = ArrayJoin(col, ":", "replacement"); - col = ArrayJoin(col, ":"); - - col = ArrayPosition(col, 1); - - col = ElementAt(col, 1); - - col = ArraySort(col); - - col = ArrayRemove(col, "elementToRemove"); - - col = ArrayDistinct(col); - - col = ArrayIntersect(col, col); - - col = ArrayUnion(col, col); - - col = ArrayExcept(col, col); - - var options = new Dictionary() { { "hello", "world" } }; - Column schema = SchemaOfJson("[{\"col\":0}]"); - - col = FromJson(col, schema); - col = FromJson(col, schema, options); - - col = SchemaOfJson("{}"); - col = SchemaOfJson(col); - - col = ArrayMin(col); - - col = ArrayMax(col); - - col = Shuffle(col); - - col = Reverse(col); - - col = Flatten(col); - - col = Sequence(col, col, col); - col = Sequence(col, col); - - col = ArrayRepeat(col, col); - col = ArrayRepeat(col, 5); - - col = MapFromEntries(col); - - col = ArraysZip(); - col = ArraysZip(col); - col = ArraysZip(col, col); - - col = MapConcat(); - col = MapConcat(col); - col = MapConcat(col, col); - } - /// /// Test signatures for APIs introduced in Spark 3.0.*. /// diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/RuntimeConfigTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/RuntimeConfigTests.cs index 4315e5e0c..4994c2803 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/RuntimeConfigTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/RuntimeConfigTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.Spark.E2ETest.Utils; using Microsoft.Spark.Sql; using Xunit; @@ -19,12 +18,12 @@ public RuntimeConfigTests(SparkFixture fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// The purpose of this test is to ensure that JVM calls can be successfully made. /// Note that this is not testing functionality of each function. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { RuntimeConfig conf = _spark.Conf(); @@ -39,15 +38,6 @@ public void TestSignaturesV2_3_X() conf.Unset("stringKey"); Assert.Equal("defaultValue", conf.Get("stringKey", "defaultValue")); Assert.Equal("false", conf.Get("boolKey", "true")); - } - - /// - /// Test signatures for APIs introduced in Spark 2.4.*. 
- /// - [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)] - public void TestSignaturesV2_4_X() - { - RuntimeConfig conf = _spark.Conf(); Assert.True(conf.IsModifiable("spark.sql.streaming.checkpointLocation")); Assert.False(conf.IsModifiable("missingKey")); diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs index cc1542a42..312822f5d 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; -using Microsoft.Spark.E2ETest.Utils; using Microsoft.Spark.Sql; using Microsoft.Spark.Sql.Catalog; using Microsoft.Spark.Sql.Streaming; @@ -24,12 +23,12 @@ public SparkSessionTests(SparkFixture fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// The purpose of this test is to ensure that JVM calls can be successfully made. /// Note that this is not testing functionality of each function. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { Assert.IsType(_spark.SparkContext); @@ -64,14 +63,7 @@ public void TestSignaturesV2_3_X() Assert.IsType(_spark.Udf()); Assert.IsType(_spark.Catalog); - } - /// - /// Test signatures for APIs introduced in Spark 2.4.*. - /// - [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)] - public void TestSignaturesV2_4_X() - { Assert.IsType(SparkSession.Active()); } diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamReaderTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamReaderTests.cs index 11c6c1a7a..f1a2cf83b 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamReaderTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamReaderTests.cs @@ -25,10 +25,10 @@ public DataStreamReaderTests(SparkFixture fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { DataStreamReader dsr = _spark.ReadStream(); diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs index 783dff5f2..cce66d4d0 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs @@ -28,10 +28,10 @@ public DataStreamWriterTests(SparkFixture fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { DataFrame df = _spark .ReadStream() diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/StreamingQueryManagerTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/StreamingQueryManagerTests.cs index f16bab0a9..a9badc5a6 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/StreamingQueryManagerTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/StreamingQueryManagerTests.cs @@ -21,12 +21,12 @@ public StreamingQueryManagerTests(SparkFixture fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. 
/// The purpose of this test is to ensure that JVM calls can be successfully made. /// Note that this is not testing functionality of each function. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { var intMemoryStream = new MemoryStream(_spark); StreamingQuery sq1 = intMemoryStream diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/StreamingQueryTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/StreamingQueryTests.cs index e2efbbd76..ed6fcc745 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/StreamingQueryTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/StreamingQueryTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Linq; using Microsoft.Spark.E2ETest.Utils; using Microsoft.Spark.Sql; using Microsoft.Spark.Sql.Streaming; @@ -21,12 +20,12 @@ public StreamingQueryTests(SparkFixture fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. + /// Test signatures for APIs up to Spark 2.4.*. /// The purpose of this test is to ensure that JVM calls can be successfully made. /// Note that this is not testing functionality of each function. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { var intMemoryStream = new MemoryStream(_spark); StreamingQuery sq = intMemoryStream From b25c3d4b94bd611bfa0ba7cc9d24398db6a011a6 Mon Sep 17 00:00:00 2001 From: Steve Suh Date: Thu, 8 Apr 2021 13:19:31 -0700 Subject: [PATCH 5/8] remove unused code and unsupported scenario. --- .../PayloadWriter.cs | 53 ++----------------- src/csharp/Microsoft.Spark/RDD.cs | 15 +----- 2 files changed, 6 insertions(+), 62 deletions(-) diff --git a/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs b/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs index d556b6e89..f011455f9 100644 --- a/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs +++ b/src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs @@ -50,20 +50,6 @@ internal interface ITaskContextWriter void Write(Stream stream, TaskContext taskContext); } - /// - /// TaskContextWriter for version 2.3.*. - /// - internal sealed class TaskContextWriterV2_3_X : ITaskContextWriter - { - public void Write(Stream stream, TaskContext taskContext) - { - SerDe.Write(stream, taskContext.StageId); - SerDe.Write(stream, taskContext.PartitionId); - SerDe.Write(stream, taskContext.AttemptNumber); - SerDe.Write(stream, taskContext.AttemptId); - } - } - /// /// TaskContextWriter for version 2.4.*. /// @@ -136,21 +122,9 @@ internal interface IBroadcastVariableWriter } /// - /// BroadcastVariableWriter for version 2.3.0 and 2.3.1. - /// - internal sealed class BroadcastVariableWriterV2_3_0 : IBroadcastVariableWriter - { - public void Write(Stream stream, BroadcastVariables broadcastVars) - { - Debug.Assert(broadcastVars.Count == 0); - SerDe.Write(stream, broadcastVars.Count); - } - } - - /// - /// BroadcastVariableWriter for version 2.3.2 and up. + /// BroadcastVariableWriter for version 2.4.*. /// - internal sealed class BroadcastVariableWriterV2_3_2 : IBroadcastVariableWriter + internal sealed class BroadcastVariableWriterV2_4_X : IBroadcastVariableWriter { public void Write(Stream stream, BroadcastVariables broadcastVars) { @@ -207,25 +181,6 @@ public void Write(Stream stream, Command[] commands) } } - /// - /// CommandWriter for version 2.3.*. 
- /// - internal sealed class CommandWriterV2_3_X : CommandWriterBase, ICommandWriter - { - public void Write(Stream stream, CommandPayload commandPayload) - { - SerDe.Write(stream, (int)commandPayload.EvalType); - - Write(stream, commandPayload.Commands); - - if ((commandPayload.EvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF) || - (commandPayload.EvalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)) - { - SerDe.Write(stream, "unused timezone"); - } - } - } - /// /// CommandWriter for version 2.4.*. /// @@ -347,13 +302,13 @@ internal PayloadWriter Create(Version version = null) return new PayloadWriter( version, new TaskContextWriterV2_4_X(), - new BroadcastVariableWriterV2_3_2(), + new BroadcastVariableWriterV2_4_X(), new CommandWriterV2_4_X()); case Versions.V3_0_0: return new PayloadWriter( version, new TaskContextWriterV3_0_X(), - new BroadcastVariableWriterV2_3_2(), + new BroadcastVariableWriterV2_4_X(), new CommandWriterV2_4_X()); default: throw new NotSupportedException($"Spark {version} is not supported."); diff --git a/src/csharp/Microsoft.Spark/RDD.cs b/src/csharp/Microsoft.Spark/RDD.cs index d3d6fff9e..86f85898a 100644 --- a/src/csharp/Microsoft.Spark/RDD.cs +++ b/src/csharp/Microsoft.Spark/RDD.cs @@ -303,19 +303,8 @@ internal virtual RDD MapPartitionsWithIndexInternal( "collectAndServe", rddRef.Invoke("rdd")); - if (result is int @port) - { - // In spark 2.3.0, collectToPython() returns a port number. - return (@port, string.Empty); - } - else - { - // From spark >= 2.3.1, collectToPython() returns a pair - // where the first is a port number and the second is the secret - // string to use for the authentication. - var pair = (JvmObjectReference[])result; - return ((int)pair[0].Invoke("intValue"), (string)pair[1].Invoke("toString")); - } + var pair = (JvmObjectReference[])result; + return ((int)pair[0].Invoke("intValue"), (string)pair[1].Invoke("toString")); } /// From 753f85c8d74e68ffb2e349fe38ebc1a1c01f397a Mon Sep 17 00:00:00 2001 From: Steve Suh Date: Thu, 8 Apr 2021 13:27:03 -0700 Subject: [PATCH 6/8] cleanup --- src/csharp/Microsoft.Spark/RDD.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/csharp/Microsoft.Spark/RDD.cs b/src/csharp/Microsoft.Spark/RDD.cs index 86f85898a..e9e2975ae 100644 --- a/src/csharp/Microsoft.Spark/RDD.cs +++ b/src/csharp/Microsoft.Spark/RDD.cs @@ -298,12 +298,11 @@ internal virtual RDD MapPartitionsWithIndexInternal( private (int, string) CollectAndServe() { JvmObjectReference rddRef = GetJvmRef(); - object result = rddRef.Jvm.CallStaticJavaMethod( + var pair = (JvmObjectReference[])rddRef.Jvm.CallStaticJavaMethod( "org.apache.spark.api.python.PythonRDD", "collectAndServe", rddRef.Invoke("rdd")); - var pair = (JvmObjectReference[])result; return ((int)pair[0].Invoke("intValue"), (string)pair[1].Invoke("toString")); } From a0d3a7c836dec62b2df1feeb456ea8c6c1a5ac7e Mon Sep 17 00:00:00 2001 From: Steve Suh Date: Tue, 13 Apr 2021 13:23:56 -0700 Subject: [PATCH 7/8] update test name --- .../Microsoft.Spark.E2ETest/IpcTests/Sql/CatalogTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/CatalogTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/CatalogTests.cs index 40d4649c8..f5f37dd91 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/CatalogTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/CatalogTests.cs @@ -24,10 +24,10 @@ public CatalogTests(SparkFixture fixture) } /// - /// Test signatures for APIs up to Spark 2.3.*. 
+ /// Test signatures for APIs up to Spark 2.4.*. /// [Fact] - public void TestSignaturesV2_3_X() + public void TestSignaturesV2_4_X() { WithTable(_spark, new string[] { "users", "users2", "users3", "users4", "usersp" }, () => { From 5b386b0259151cea06c79de94f760b20e5e8a353 Mon Sep 17 00:00:00 2001 From: Steve Suh Date: Mon, 19 Apr 2021 11:45:37 -0700 Subject: [PATCH 8/8] PR comments. --- .../Microsoft.Spark.Worker/DaemonWorker.cs | 2 +- .../Processor/BroadcastVariableProcessor.cs | 6 +----- .../Microsoft.Spark.Worker/SimpleWorker.cs | 4 ++-- .../Utils/SettingUtils.cs | 18 ++++-------------- src/csharp/Microsoft.Spark/RDD.cs | 3 ++- src/csharp/Microsoft.Spark/Sql/DataFrame.cs | 2 ++ 6 files changed, 12 insertions(+), 23 deletions(-) diff --git a/src/csharp/Microsoft.Spark.Worker/DaemonWorker.cs b/src/csharp/Microsoft.Spark.Worker/DaemonWorker.cs index ecbd3c272..3611ba17f 100644 --- a/src/csharp/Microsoft.Spark.Worker/DaemonWorker.cs +++ b/src/csharp/Microsoft.Spark.Worker/DaemonWorker.cs @@ -110,7 +110,7 @@ private void StartServer(ISocketWrapper listener) bool reuseWorker = "1".Equals(Environment.GetEnvironmentVariable("SPARK_REUSE_WORKER")); - string secret = Utils.SettingUtils.GetWorkerFactorySecret(_version); + string secret = Utils.SettingUtils.GetWorkerFactorySecret(); int taskRunnerId = 1; int numWorkerThreads = 0; diff --git a/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs b/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs index 40a9ffae3..cb1fa5f4a 100644 --- a/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs +++ b/src/csharp/Microsoft.Spark.Worker/Processor/BroadcastVariableProcessor.cs @@ -30,11 +30,7 @@ internal BroadcastVariables Process(Stream stream) var broadcastVars = new BroadcastVariables(); ISocketWrapper socket = null; - if (_version >= new Version(Versions.V2_4_0)) - { - broadcastVars.DecryptionServerNeeded = SerDe.ReadBool(stream); - } - + broadcastVars.DecryptionServerNeeded = SerDe.ReadBool(stream); broadcastVars.Count = Math.Max(SerDe.ReadInt32(stream), 0); if (broadcastVars.DecryptionServerNeeded) diff --git a/src/csharp/Microsoft.Spark.Worker/SimpleWorker.cs b/src/csharp/Microsoft.Spark.Worker/SimpleWorker.cs index 34b6d2081..06e71338b 100644 --- a/src/csharp/Microsoft.Spark.Worker/SimpleWorker.cs +++ b/src/csharp/Microsoft.Spark.Worker/SimpleWorker.cs @@ -23,7 +23,7 @@ internal SimpleWorker(Version version) internal void Run() { - int port = Utils.SettingUtils.GetWorkerFactoryPort(_version); + int port = Utils.SettingUtils.GetWorkerFactoryPort(); Run(port); } @@ -31,7 +31,7 @@ internal void Run(int port) { try { - string secret = Utils.SettingUtils.GetWorkerFactorySecret(_version); + string secret = Utils.SettingUtils.GetWorkerFactorySecret(); s_logger.LogInfo($"RunSimpleWorker() is starting with port = {port}."); diff --git a/src/csharp/Microsoft.Spark.Worker/Utils/SettingUtils.cs b/src/csharp/Microsoft.Spark.Worker/Utils/SettingUtils.cs index a1086be44..73698fef8 100644 --- a/src/csharp/Microsoft.Spark.Worker/Utils/SettingUtils.cs +++ b/src/csharp/Microsoft.Spark.Worker/Utils/SettingUtils.cs @@ -12,20 +12,10 @@ namespace Microsoft.Spark.Worker.Utils /// internal static class SettingUtils { - internal static string GetWorkerFactorySecret(Version version) - { - return (version >= new Version(Versions.V2_4_0)) ? 
- GetEnvironmentVariable("PYTHON_WORKER_FACTORY_SECRET") : - null; - } + internal static string GetWorkerFactorySecret() => + GetEnvironmentVariable("PYTHON_WORKER_FACTORY_SECRET"); - internal static int GetWorkerFactoryPort(Version version) - { - string portStr = (version >= new Version(Versions.V2_4_0)) ? - GetEnvironmentVariable("PYTHON_WORKER_FACTORY_PORT") : - Console.ReadLine(); - - return int.Parse(portStr.Trim()); - } + internal static int GetWorkerFactoryPort() => + int.Parse(GetEnvironmentVariable("PYTHON_WORKER_FACTORY_PORT").Trim()); } } diff --git a/src/csharp/Microsoft.Spark/RDD.cs b/src/csharp/Microsoft.Spark/RDD.cs index e9e2975ae..8c23cc87a 100644 --- a/src/csharp/Microsoft.Spark/RDD.cs +++ b/src/csharp/Microsoft.Spark/RDD.cs @@ -298,11 +298,12 @@ internal virtual RDD MapPartitionsWithIndexInternal( private (int, string) CollectAndServe() { JvmObjectReference rddRef = GetJvmRef(); + // collectToPython() returns a pair where the first is a port number + // and the second is the secret string to use for the authentication. var pair = (JvmObjectReference[])rddRef.Jvm.CallStaticJavaMethod( "org.apache.spark.api.python.PythonRDD", "collectAndServe", rddRef.Invoke("rdd")); - return ((int)pair[0].Invoke("intValue"), (string)pair[1].Invoke("toString")); } diff --git a/src/csharp/Microsoft.Spark/Sql/DataFrame.cs b/src/csharp/Microsoft.Spark/Sql/DataFrame.cs index 78807e3fd..e3adc8ce5 100644 --- a/src/csharp/Microsoft.Spark/Sql/DataFrame.cs +++ b/src/csharp/Microsoft.Spark/Sql/DataFrame.cs @@ -1104,6 +1104,8 @@ private IEnumerable GetRows(string funcName, params object[] args) Version version = SparkEnvironment.SparkVersion; return (version.Major, version.Minor) switch { + // PythonFunction.serveIterator() returns a pair where the first is a port + // number and the second is the secret string to use for the authentication. (2, 4) => ParseConnectionInfo(result, false), (3, _) => ParseConnectionInfo(result, false), _ => throw new NotSupportedException($"Spark {version} not supported.")
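For context on the RDD.CollectAndServe and DataFrame.GetRows changes above: both rely on the same contract, namely that the JVM call hands back a two-element pair whose first element is a port number and whose second is the secret string used to authenticate the socket connection. The following is a minimal standalone C# sketch of that parsing step only; the names ConnectionInfoExample and ParsePair are hypothetical and this is not the library's internal code.

    using System;

    // Illustrative only; models the (port, secret) pair returned by
    // collectAndServe / serveIterator-style JVM calls.
    internal static class ConnectionInfoExample
    {
        internal static (int Port, string Secret) ParsePair(object[] pair)
        {
            if (pair == null || pair.Length < 2)
            {
                throw new ArgumentException("Expected a (port, secret) pair.", nameof(pair));
            }

            // First element: port of the JVM-side server.
            // Second element: secret used to authenticate the connection.
            return (Convert.ToInt32(pair[0]), Convert.ToString(pair[1]));
        }

        private static void Main()
        {
            (int port, string secret) = ParsePair(new object[] { 55123, "abc123" });
            Console.WriteLine($"port={port}, secret={secret}");
        }
    }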
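Likewise, the SettingUtils change in PATCH 8/8 drops the Spark-version check, so the worker factory secret and port are always read from the PYTHON_WORKER_FACTORY_SECRET and PYTHON_WORKER_FACTORY_PORT environment variables. A minimal sketch of that always-from-environment behavior follows; the class name WorkerSettingsSketch is hypothetical and the snippet only mirrors the shape of the patched helpers.

    using System;

    // Minimal sketch: with Spark 2.3 support removed, both values
    // come straight from environment variables set by the launcher.
    internal static class WorkerSettingsSketch
    {
        internal static string GetSecret() =>
            Environment.GetEnvironmentVariable("PYTHON_WORKER_FACTORY_SECRET");

        internal static int GetPort() =>
            int.Parse(Environment.GetEnvironmentVariable("PYTHON_WORKER_FACTORY_PORT").Trim());
    }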