Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-44263][CONNECT] Custom Interceptors Support #41880

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ import java.io.Closeable
import java.net.URI
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.AtomicLong

import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag

import com.google.common.cache.{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
cdkrot marked this conversation as resolved.
Show resolved Hide resolved
import org.apache.arrow.memory.RootAllocator

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
Expand Down Expand Up @@ -629,7 +627,7 @@ object SparkSession extends Logging {
/**
 * Create a new [[SparkSession]] based on the connect client [[Configuration]].
 */
private[sql] def create(configuration: Configuration): SparkSession = {
  // Build the client (and its gRPC channel) from the configuration first,
  // then wrap it in a session sharing the object-level cleaner and plan-id generator.
  val client = configuration.toSparkConnectClient
  new SparkSession(client, cleaner, planIdGenerator)
}

/**
Expand All @@ -656,6 +654,16 @@ object SparkSession extends Logging {
this
}

/**
 * Add an interceptor to be used during channel creation.
 *
 * Note that interceptors added last are executed first by gRPC.
 *
 * @param interceptor the gRPC [[ClientInterceptor]] to register on the underlying client
 * @return this builder, for chaining
 */
def interceptor(interceptor: ClientInterceptor): Builder = {
  // Delegates to the SparkConnectClient builder; the interceptor is applied
  // when the client's gRPC channel is created.
  builder.interceptor(interceptor)
  this
}

private[sql] def client(client: SparkConnectClient): Builder = {
this.client = client
this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ private[sql] class SparkConnectClient(
private[sql] val configuration: SparkConnectClient.Configuration,
private val channel: ManagedChannel) {

def this(configuration: SparkConnectClient.Configuration) =
this(configuration, configuration.createChannel())

private val userContext: UserContext = configuration.userContext

private[this] val bstub = new CustomSparkConnectBlockingStub(channel, configuration.retryPolicy)
Expand Down Expand Up @@ -198,7 +195,7 @@ private[sql] class SparkConnectClient(
bstub.interrupt(request)
}

def copy(): SparkConnectClient = new SparkConnectClient(configuration)
def copy(): SparkConnectClient = configuration.toSparkConnectClient

/**
* Add a single artifact to the client session.
Expand Down Expand Up @@ -463,7 +460,18 @@ object SparkConnectClient {
this
}

def build(): SparkConnectClient = new SparkConnectClient(_configuration)
/**
 * Add an interceptor to be used during channel creation.
 *
 * Note that interceptors added last are executed first by gRPC.
 *
 * @param interceptor the gRPC [[ClientInterceptor]] to register
 * @return this builder, for chaining
 */
def interceptor(interceptor: ClientInterceptor): Builder = {
  // Append with `:+` instead of `++ List(interceptor)`: same ordering, clearer
  // intent, and avoids allocating a throwaway single-element list.
  _configuration = _configuration.copy(interceptors = _configuration.interceptors :+ interceptor)
  this
}

def build(): SparkConnectClient = _configuration.toSparkConnectClient
}

/**
Expand All @@ -478,7 +486,8 @@ object SparkConnectClient {
isSslEnabled: Option[Boolean] = None,
metadata: Map[String, String] = Map.empty,
userAgent: String = DEFAULT_USER_AGENT,
retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy()) {
retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy(),
interceptors: List[ClientInterceptor] = List.empty) {

def userContext: proto.UserContext = {
val builder = proto.UserContext.newBuilder()
Expand Down Expand Up @@ -509,12 +518,18 @@ object SparkConnectClient {

/**
 * Creates the gRPC [[ManagedChannel]] for this configuration, registering the
 * metadata-header interceptor (when any metadata is configured) followed by all
 * user-supplied interceptors. Per gRPC semantics, interceptors registered last
 * run first on outgoing calls.
 */
def createChannel(): ManagedChannel = {
  val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, credentials)

  if (metadata.nonEmpty) {
    channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
  }

  // Fix: `forEach` is the java.util method and does not exist on Scala's List;
  // use Scala's `foreach` so this compiles and registers every interceptor.
  interceptors.foreach(channelBuilder.intercept(_))

  channelBuilder.maxInboundMessageSize(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE)
  channelBuilder.build()
}

def toSparkConnectClient: SparkConnectClient = new SparkConnectClient(this, createChannel())
}

/**
Expand All @@ -540,10 +555,6 @@ object SparkConnectClient {
}
})
}

override def thisUsesUnstableApi(): Unit = {
// Marks this API is not stable. Left empty on purpose.
}
Copy link
Contributor Author

@cdkrot cdkrot Jul 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API is no longer unstable (was stabilized) so this can be removed. See grpc/grpc-java#1914

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
*/
package org.apache.spark.sql


import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor, Server}

import org.apache.spark.sql.connect.client.util.ConnectFunSuite

/**
* Tests for non-dataframe related SparkSession operations.
*/
class SparkSessionSuite extends ConnectFunSuite {
class SparkSessionSuite extends ConnectFunSuite with BeforeAndAfterEach {
test("default") {
val session = SparkSession.builder().getOrCreate()
assert(session.client.configuration.host == "localhost")
Expand Down Expand Up @@ -73,4 +76,20 @@ class SparkSessionSuite extends ConnectFunSuite {
session2.close()
}
}

test("Custom Interceptor") {
  // An interceptor that fails every call: if it is actually installed on the
  // channel, any RPC issued through the session must throw.
  val blockingInterceptor = new ClientInterceptor {
    override def interceptCall[ReqT, RespT](
        methodDescriptor: MethodDescriptor[ReqT, RespT],
        callOptions: CallOptions,
        channel: Channel): ClientCall[ReqT, RespT] = {
      throw new RuntimeException("Blocked")
    }
  }

  val session = SparkSession.builder().interceptor(blockingInterceptor).create()

  assertThrows[RuntimeException] {
    session.range(10).count()
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.collection.mutable

import io.grpc.{Server, Status, StatusRuntimeException}
import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor, Server}
import io.grpc.netty.NettyServerBuilder
import io.grpc.stub.StreamObserver
import org.scalatest.BeforeAndAfterEach
Expand Down Expand Up @@ -123,6 +123,26 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
assert(df.plan === service.getAndClearLatestInputPlan())
}

test("Custom Interceptor") {
  startDummyServer(0)
  // Register an interceptor that fails every call; a blocked RPC proves the
  // builder actually wires interceptors into the channel.
  client = SparkConnectClient
    .builder()
    .connectionString(s"sc://localhost:${server.getPort}")
    .interceptor(new ClientInterceptor {
      override def interceptCall[ReqT, RespT](
          methodDescriptor: MethodDescriptor[ReqT, RespT],
          callOptions: CallOptions,
          channel: Channel): ClientCall[ReqT, RespT] = {
        throw new RuntimeException("Blocked")
      }
    })
    .build()

  val session = SparkSession.builder().client(client).create()

  assertThrows[RuntimeException] {
    session.range(10).count()
  }
}

private case class TestPackURI(
connectionString: String,
isCorrect: Boolean,
Expand Down