[SPARK-44263][CONNECT] Custom Interceptors Support #41880

Closed · wants to merge 21 commits into from

Changes from 3 commits
@@ -20,13 +20,11 @@ import java.io.Closeable
import java.net.URI
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.AtomicLong

import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag

import com.google.common.cache.{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
import org.apache.arrow.memory.RootAllocator

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
@@ -629,7 +627,7 @@ object SparkSession extends Logging {
* Create a new [[SparkSession]] based on the connect client [[Configuration]].
*/
private[sql] def create(configuration: Configuration): SparkSession = {
new SparkSession(new SparkConnectClient(configuration), cleaner, planIdGenerator)
new SparkSession(configuration.toSparkConnectClient, cleaner, planIdGenerator)
}

/**
@@ -656,6 +654,11 @@
this
}

def interceptor(interceptor: ClientInterceptor): Builder = {
Contributor: doc please :)

Contributor: This is the public-facing part of the interface and needs the most extensive documentation.

Contributor Author: Ok, good point :). I didn't do it because other public methods here are undocumented, but that's probably not a good idea either.

builder.interceptor(interceptor)
this
}
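
Not part of the diff — to address the documentation request above, a hedged sketch of how an application could register a custom interceptor through this public builder method. The pass-through interceptor is made up for illustration and connection settings are omitted; only the builder().interceptor(...).create() chain is taken from this PR and its discussion.

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor}
import org.apache.spark.sql.SparkSession

// A pass-through interceptor: forwards every Spark Connect RPC unchanged.
// A real integration would add headers, authentication, logging, etc. here.
val passThrough = new ClientInterceptor {
  override def interceptCall[ReqT, RespT](
      method: MethodDescriptor[ReqT, RespT],
      options: CallOptions,
      channel: Channel): ClientCall[ReqT, RespT] = {
    channel.newCall(method, options)
  }
}

// Register the interceptor when building the session; it is installed on the
// gRPC channel created for the underlying SparkConnectClient.
val spark = SparkSession
  .builder()
  .interceptor(passThrough)
  .create()
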

private[sql] def client(client: SparkConnectClient): Builder = {
this.client = client
this
@@ -35,9 +35,6 @@ private[sql] class SparkConnectClient(
private[sql] val configuration: SparkConnectClient.Configuration,
private val channel: ManagedChannel) {

def this(configuration: SparkConnectClient.Configuration) =
this(configuration, configuration.createChannel())

private val userContext: UserContext = configuration.userContext

private[this] val stub = new CustomSparkConnectBlockingStub(channel)
@@ -197,7 +194,7 @@ private[sql] class SparkConnectClient(
stub.interrupt(request)
}

def copy(): SparkConnectClient = new SparkConnectClient(configuration)
def copy(): SparkConnectClient = configuration.toSparkConnectClient

/**
* Add a single artifact to the client session.
@@ -457,7 +454,18 @@ object SparkConnectClient {
this
}

def build(): SparkConnectClient = new SparkConnectClient(_configuration)
/**
* Add an interceptor to be used during channel creation.
*
* Note that interceptors added last are executed first by grpc.
*/
def interceptor(interceptor: ClientInterceptor): Builder = {
val interceptors = _configuration.interceptors ++ List(interceptor)
_configuration = _configuration.copy(interceptors = interceptors)
this
}

def build(): SparkConnectClient = _configuration.toSparkConnectClient
}
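
Not part of the diff — a hedged sketch of the ordering behaviour documented in the new interceptor method above: the interceptor registered last on the builder wraps the channel outermost and therefore runs first. The namedInterceptor helper and the endpoint are made up for illustration; connectionString, interceptor, and build are the builder methods used elsewhere in this PR.

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor}

// A trivial interceptor that prints its name before forwarding the call unchanged.
def namedInterceptor(name: String): ClientInterceptor = new ClientInterceptor {
  override def interceptCall[ReqT, RespT](
      method: MethodDescriptor[ReqT, RespT],
      options: CallOptions,
      channel: Channel): ClientCall[ReqT, RespT] = {
    println(s"intercepted by $name") // ordering marker
    channel.newCall(method, options)
  }
}

val client = SparkConnectClient
  .builder()
  .connectionString("sc://localhost:15002") // placeholder endpoint
  .interceptor(namedInterceptor("first"))   // added first, runs second
  .interceptor(namedInterceptor("second"))  // added last, runs first
  .build()
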

/**
@@ -471,7 +479,8 @@
token: Option[String] = None,
isSslEnabled: Option[Boolean] = None,
metadata: Map[String, String] = Map.empty,
userAgent: String = DEFAULT_USER_AGENT) {
userAgent: String = DEFAULT_USER_AGENT,
interceptors: List[ClientInterceptor] = List.empty) {

def userContext: proto.UserContext = {
val builder = proto.UserContext.newBuilder()
Expand Down Expand Up @@ -502,12 +511,20 @@ object SparkConnectClient {

def createChannel(): ManagedChannel = {
val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, credentials)

if (metadata.nonEmpty) {
channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
}

for (interceptor <- interceptors) {
channelBuilder.intercept(interceptor)
}

channelBuilder.maxInboundMessageSize(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE)
channelBuilder.build()
}

def toSparkConnectClient: SparkConnectClient = new SparkConnectClient(this, createChannel())
}

/**
@@ -533,10 +550,6 @@
}
})
}

override def thisUsesUnstableApi(): Unit = {
// Marks this API is not stable. Left empty on purpose.
}
Contributor Author (@cdkrot, Jul 6, 2023): This API is no longer unstable (it was stabilized), so this override can be removed. See grpc/grpc-java#1914.

}
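
Not part of the diff — for context on the "custom auth integrations" use case raised later in the review, a minimal sketch of the kind of user-supplied interceptor this change enables, loosely modelled on the MetadataHeaderClientInterceptor above. The class name, header name, and token are made up for illustration.

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, ForwardingClientCall, Metadata, MethodDescriptor}

// Attaches a static authentication header to every outgoing Spark Connect RPC.
class AuthHeaderInterceptor(token: String) extends ClientInterceptor {
  private val authKey: Metadata.Key[String] =
    Metadata.Key.of("x-custom-auth-token", Metadata.ASCII_STRING_MARSHALLER)

  override def interceptCall[ReqT, RespT](
      method: MethodDescriptor[ReqT, RespT],
      options: CallOptions,
      channel: Channel): ClientCall[ReqT, RespT] = {
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
        channel.newCall(method, options)) {
      override def start(listener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        headers.put(authKey, token) // inject the header before the call starts
        super.start(listener, headers)
      }
    }
  }
}
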

/**
@@ -17,7 +17,6 @@
package org.apache.spark.sql.connect.client

import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
import scala.collection.mutable

@@ -123,6 +122,28 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
assert(df.plan === service.getAndClearLatestInputPlan())
}

test("CustomInterceptor") {
startDummyServer(0)
client = SparkConnectClient
.builder()
.connectionString(s"sc://localhost:${server.getPort}")
.interceptor(new ClientInterceptor {
override def interceptCall[ReqT, RespT](
methodDescriptor: MethodDescriptor[ReqT, RespT],
callOptions: CallOptions,
channel: Channel): ClientCall[ReqT, RespT] = {
throw new RuntimeException("Blocked")
}
})
.build()

val session = SparkSession.builder().client(client).create()
Contributor: The client() method is private[sql]; for anyone providing any kind of wrapper on top of the SparkSession for custom auth integrations, it's not possible to use this method.

Contributor Author (@cdkrot, Jul 6, 2023): The proposed usage is rather SparkSession.builder().interceptor(myInterceptor).create(), and that one is public. Perhaps I should add a test doing this exact line nearby.

Contributor Author: Now there is an external API test as well (in SparkSessionSuite); a sketch of such a test follows below.


assertThrows[RuntimeException] {
session.range(10).count()
}
}
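
Not part of the diff — following up on the thread above, a hedged sketch of what the external-API test mentioned for SparkSessionSuite could look like. It assumes the same imports and dummy-server helpers as the test above, and that the session builder exposes a remote(...) setter; neither is shown in this diff.

test("Custom interceptor through the public SparkSession builder") {
  startDummyServer(0) // reusing this suite's helper for illustration
  val session = SparkSession
    .builder()
    .remote(s"sc://localhost:${server.getPort}") // assumed builder setter
    .interceptor(new ClientInterceptor {
      override def interceptCall[ReqT, RespT](
          methodDescriptor: MethodDescriptor[ReqT, RespT],
          callOptions: CallOptions,
          channel: Channel): ClientCall[ReqT, RespT] = {
        throw new RuntimeException("Blocked")
      }
    })
    .create()

  assertThrows[RuntimeException] {
    session.range(10).count()
  }
}
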

private case class TestPackURI(
connectionString: String,
isCorrect: Boolean,