Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3.3 and 3.4 support initially added #1150

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions src/scala/microsoft-spark-3-3/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
<!-- Maven module for the microsoft-spark worker targeting Spark 3.3.x / Scala 2.12. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- Inherits groupId/version from the aggregate microsoft-spark parent POM. -->
<parent>
<groupId>com.microsoft.scala</groupId>
<artifactId>microsoft-spark</artifactId>
<version>${microsoft-spark.version}</version>
</parent>
<!-- Artifact name encodes the Spark line (3-3) and Scala binary version (2.12). -->
<artifactId>microsoft-spark-3-3_2.12</artifactId>
<inceptionYear>2019</inceptionYear>
<properties>
<encoding>UTF-8</encoding>
<!-- NOTE(review): Spark 3.3.0 is built against Scala 2.12.15; 2.12.10 here is
     older — confirm this is intentional before release. -->
<scala.version>2.12.10</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<spark.version>3.3.0</spark.version>
</properties>

<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<!-- Spark dependencies are 'provided': supplied by the Spark runtime, not bundled. -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.specs</groupId>
<artifactId>specs</artifactId>
<version>1.2.5</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<sourceDirectory>src/main/scala</sourceDirectory>
<testSourceDirectory>src/test/scala</testSourceDirectory>
<plugins>
<!-- NOTE(review): org.scala-tools:maven-scala-plugin is long unmaintained;
     consider migrating to net.alchim31.maven:scala-maven-plugin. -->
<plugin>
<groupId>org.scala-tools</groupId>
<artifactId>maven-scala-plugin</artifactId>
<version>2.15.2</version>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
<configuration>
<scalaVersion>${scala.version}</scalaVersion>
<args>
<!-- Emit Java 8 bytecode for compatibility with Spark clusters on JDK 8. -->
<arg>-target:jvm-1.8</arg>
<arg>-deprecation</arg>
<arg>-feature</arg>
</args>
</configuration>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/

package org.apache.spark.api.dotnet

import java.io.DataOutputStream

import org.apache.spark.internal.Logging

import scala.collection.mutable.Queue

/**
 * CallbackClient is used to communicate with the Dotnet CallbackServer.
 * The client manages and maintains a pool of open CallbackConnections.
 * Any callback request is delegated to a new CallbackConnection or
 * unused CallbackConnection.
 * @param serDe Serializer/deserializer used to write callback payloads
 * @param address The address of the Dotnet CallbackServer
 * @param port The port of the Dotnet CallbackServer
 */
class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging {
  // Idle connections available for reuse; all access is guarded by
  // `synchronized` on this client instance.
  private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]()

  // Once true, no new connections are created or handed out.
  private[this] var isShutdown: Boolean = false

  /**
   * Sends a callback request over a pooled (or freshly created) connection.
   * On success the connection is returned to the pool; on failure it is
   * closed and the exception is rethrown to the caller.
   * @param callbackId Identifier of the .NET callback to invoke
   * @param writeBody Writes the callback payload to the given stream
   */
  final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit =
    getOrCreateConnection() match {
      case Some(connection) =>
        try {
          connection.send(callbackId, writeBody)
          // The connection is still healthy; make it available for reuse.
          addConnection(connection)
        } catch {
          case e: Exception =>
            logError(s"Error calling callback [callback id = $callbackId].", e)
            connection.close()
            throw e
        }
      case None => throw new Exception("Unable to get or create connection.")
    }

  /** Returns an idle pooled connection, or a new one; None once shutdown. */
  private def getOrCreateConnection(): Option[CallbackConnection] = synchronized {
    if (isShutdown) {
      logInfo("Cannot get or create connection while client is shutdown.")
      return None
    }

    if (connectionPool.nonEmpty) {
      return Some(connectionPool.dequeue())
    }

    Some(new CallbackConnection(serDe, address, port))
  }

  /** Returns a connection to the pool after a successful send. */
  private def addConnection(connection: CallbackConnection): Unit = synchronized {
    assert(connection != null)
    connectionPool.enqueue(connection)
  }

  /** Closes all pooled connections and prevents further sends. Idempotent. */
  def shutdown(): Unit = synchronized {
    if (isShutdown) {
      logInfo("Shutdown called, but already shutdown.")
      return
    }

    logInfo("Shutting down.")
    // Explicit parentheses: close() and clear() are side-effecting, and
    // auto-application of nullary methods is deprecated in Scala 2.13 and
    // removed in Scala 3 — relevant for the 3.3/3.4 cross-build this adds.
    connectionPool.foreach(_.close())
    connectionPool.clear()
    isShutdown = true
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/

package org.apache.spark.api.dotnet

import java.io.{ByteArrayOutputStream, Closeable, DataInputStream, DataOutputStream}
import java.net.Socket

import org.apache.spark.internal.Logging

/**
 * CallbackConnection is used to process the callback communication
 * between the JVM and Dotnet. It uses a TCP socket to communicate with
 * the Dotnet CallbackServer and the socket is expected to be reused.
 * @param serDe Serializer/deserializer used to encode wire messages
 * @param address The address of the Dotnet CallbackServer
 * @param port The port of the Dotnet CallbackServer
 */
class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging {
  // Connects eagerly on construction; failures surface to the caller.
  private[this] val socket: Socket = new Socket(address, port)
  private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream)
  private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream)

  /**
   * Invokes a callback on the .NET side: writes the CALLBACK flag, the
   * callback id, a length-prefixed body, then END_OF_STREAM, and waits for
   * the server to echo END_OF_STREAM as an acknowledgment.
   * @param callbackId Identifier of the .NET callback to invoke
   * @param writeBody Writes the callback payload to the given stream
   */
  def send(
      callbackId: Int,
      writeBody: (DataOutputStream, SerDe) => Unit): Unit = {
    logInfo(s"Calling callback [callback id = $callbackId] ...")

    try {
      serDe.writeInt(outputStream, CallbackFlags.CALLBACK)
      serDe.writeInt(outputStream, callbackId)

      // Buffer the body first so its byte length can be written ahead of it.
      val byteArrayOutputStream = new ByteArrayOutputStream()
      writeBody(new DataOutputStream(byteArrayOutputStream), serDe)
      serDe.writeInt(outputStream, byteArrayOutputStream.size)
      byteArrayOutputStream.writeTo(outputStream)
    } catch {
      case e: Exception => {
        throw new Exception("Error writing to stream.", e)
      }
    }

    logInfo("Signaling END_OF_STREAM.")
    try {
      serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM)
      outputStream.flush()

      // The server must answer END_OF_STREAM; anything else is a protocol error.
      val endOfStreamResponse = readFlag(inputStream)
      endOfStreamResponse match {
        case CallbackFlags.END_OF_STREAM =>
          logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.")
        case _ => {
          throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " +
            s"Received: $endOfStreamResponse")
        }
      }
    } catch {
      case e: Exception => {
        throw new Exception("Error while verifying end of stream.", e)
      }
    }
  }

  /**
   * Best-effort close: notifies the server with the CLOSE flag, then closes
   * the socket and both streams. Never throws.
   */
  def close(): Unit = {
    try {
      serDe.writeInt(outputStream, CallbackFlags.CLOSE)
      outputStream.flush()
    } catch {
      case e: Exception => logInfo("Unable to send close to .NET callback server.", e)
    }

    close(socket)
    close(outputStream)
    close(inputStream)
  }

  /** Closes the socket, logging (not propagating) any failure. */
  private def close(s: Socket): Unit = {
    try {
      assert(s != null)
      s.close()
    } catch {
      case e: Exception => logInfo("Unable to close socket.", e)
    }
  }

  /** Closes a stream, logging (not propagating) any failure. */
  private def close(c: Closeable): Unit = {
    try {
      assert(c != null)
      c.close()
    } catch {
      case e: Exception => logInfo("Unable to close closeable.", e)
    }
  }

  /**
   * Reads the next flag from the server. If the server reports a .NET
   * exception, reads its message and rethrows it as a DotnetException.
   */
  private def readFlag(inputStream: DataInputStream): Int = {
    val callbackFlag = serDe.readInt(inputStream)
    if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) {
      val exceptionMessage = serDe.readString(inputStream)
      throw new DotnetException(exceptionMessage)
    }
    callbackFlag
  }

  // Negative sentinel values shared with the .NET side of the protocol.
  private object CallbackFlags {
    val CLOSE: Int = -1
    val CALLBACK: Int = -2
    val DOTNET_EXCEPTION_THROWN: Int = -3
    val END_OF_STREAM: Int = -4
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/

package org.apache.spark.api.dotnet

import java.net.InetSocketAddress
import java.util.concurrent.TimeUnit
import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
import io.netty.handler.codec.LengthFieldBasedFrameDecoder
import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_NUM_BACKEND_THREADS
import org.apache.spark.{SparkConf, SparkEnv}

/**
 * Netty server that invokes JVM calls based upon receiving messages from .NET.
 * The implementation mirrors the RBackend.
 *
 */
class DotnetBackend extends Logging {
  self => // for accessing the this reference in inner class(ChannelInitializer)
  private[this] var channelFuture: ChannelFuture = _
  private[this] var bootstrap: ServerBootstrap = _
  private[this] var bossGroup: EventLoopGroup = _
  // Tracks JVM objects referenced by the .NET side; cleared on close().
  private[this] val objectTracker = new JVMObjectTracker

  // Client for calling back into the .NET CallbackServer; set lazily via
  // setCallbackClient. @volatile so readers on other threads see the update.
  @volatile
  private[dotnet] var callbackClient: Option[CallbackClient] = None

  /**
   * Starts the Netty server bound to localhost on the requested port.
   * @param portNumber Port to bind; 0 lets the OS pick a free port.
   * @return The actual bound port.
   */
  def init(portNumber: Int): Int = {
    // Fall back to a fresh SparkConf when no SparkEnv exists yet.
    val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
    val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS)
    logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.")
    bossGroup = new NioEventLoopGroup(numBackendThreads)
    // A single event loop group serves as both boss and worker, like RBackend.
    val workerGroup = bossGroup

    bootstrap = new ServerBootstrap()
      .group(bossGroup, workerGroup)
      .channel(classOf[NioServerSocketChannel])

    bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
      def initChannel(ch: SocketChannel): Unit = {
        // Outbound: raw bytes; inbound: 4-byte length-prefixed frames, then
        // the handler that dispatches JVM method calls.
        ch.pipeline()
          .addLast("encoder", new ByteArrayEncoder())
          .addLast(
            "frameDecoder",
            // maxFrameLength = 2G
            // lengthFieldOffset = 0
            // lengthFieldLength = 4
            // lengthAdjustment = 0
            // initialBytesToStrip = 4, i.e. strip out the length field itself
            new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
          .addLast("decoder", new ByteArrayDecoder())
          .addLast("handler", new DotnetBackendHandler(self, objectTracker))
      }
    })

    // Bind synchronously and report the port actually assigned.
    channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber))
    channelFuture.syncUninterruptibly()
    channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
  }

  /** Connects the callback client; throws if one has already been set. */
  private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized {
    callbackClient = callbackClient match {
      case Some(_) => throw new Exception("Callback client already set.")
      case None =>
        logInfo(s"Connecting to a callback server at $address:$port")
        Some(new CallbackClient(new SerDe(objectTracker), address, port))
    }
  }

  /** Shuts down the callback client if present; safe to call repeatedly. */
  private[dotnet] def shutdownCallbackClient(): Unit = synchronized {
    callbackClient match {
      case Some(client) => client.shutdown()
      case None => logInfo("Callback server has already been shutdown.")
    }
    callbackClient = None
  }

  /** Blocks the calling thread until the server channel closes. */
  def run(): Unit = {
    channelFuture.channel.closeFuture().syncUninterruptibly()
  }

  /** Releases all server resources: channel, event loops, tracked objects,
   *  callback client, and the shared thread pool. */
  def close(): Unit = {
    if (channelFuture != null) {
      // close is a local operation and should finish within milliseconds; timeout just to be safe
      channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS)
      channelFuture = null
    }
    if (bootstrap != null && bootstrap.config().group() != null) {
      bootstrap.config().group().shutdownGracefully()
    }
    if (bootstrap != null && bootstrap.config().childGroup() != null) {
      bootstrap.config().childGroup().shutdownGracefully()
    }
    bootstrap = null

    objectTracker.clear()

    // Send close to .NET callback server.
    shutdownCallbackClient()

    // Shutdown the thread pool whose executors could still be running.
    ThreadPool.shutdown()
  }
}
Loading