Skip to content

Commit

Permalink
[SPARK-44263][CONNECT] Custom Interceptors Support
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Extend SparkSession to allow custom interceptors (grpc's ClientInterceptor)

### Why are the changes needed?
This can be used for different purposes, such as customized authentication mechanisms

This is similar to Python which allows custom channel builders (#40993)

### Does this PR introduce _any_ user-facing change?
SparkSession allows custom interceptors. The proposed usage is

```
val interceptor = new ClientInterceptor {....}
val session = SparkSession.builder().interceptor(interceptor).create()
```

Or same with more than one interceptor

### How was this patch tested?
UT

Closes #41880 from cdkrot/scala_channel_builder_2.

Lead-authored-by: Alice Sayutina <[email protected]>
Co-authored-by: Alice Sayutina <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
2 people authored and HyukjinKwon committed Jul 11, 2023
1 parent c5a23e9 commit eb07110
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag

import com.google.common.cache.{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
import org.apache.arrow.memory.RootAllocator

import org.apache.spark.annotation.{DeveloperApi, Experimental}
Expand Down Expand Up @@ -629,7 +630,7 @@ object SparkSession extends Logging {
* Create a new [[SparkSession]] based on the connect client [[Configuration]].
*/
private[sql] def create(configuration: Configuration): SparkSession = {
new SparkSession(new SparkConnectClient(configuration), cleaner, planIdGenerator)
new SparkSession(configuration.toSparkConnectClient, cleaner, planIdGenerator)
}

/**
Expand All @@ -656,6 +657,18 @@ object SparkSession extends Logging {
this
}

/**
* Add an interceptor [[ClientInterceptor]] to be used during channel creation.
*
* Note that interceptors added last are executed first by gRPC.
*
* @since 3.5.0
*/
def interceptor(interceptor: ClientInterceptor): Builder = {
builder.interceptor(interceptor)
this
}

private[sql] def client(client: SparkConnectClient): Builder = {
this.client = client
this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ private[sql] class SparkConnectClient(
private[sql] val configuration: SparkConnectClient.Configuration,
private val channel: ManagedChannel) {

def this(configuration: SparkConnectClient.Configuration) =
this(configuration, configuration.createChannel())

private val userContext: UserContext = configuration.userContext

private[this] val bstub = new CustomSparkConnectBlockingStub(channel, configuration.retryPolicy)
Expand Down Expand Up @@ -198,7 +195,7 @@ private[sql] class SparkConnectClient(
bstub.interrupt(request)
}

def copy(): SparkConnectClient = new SparkConnectClient(configuration)
def copy(): SparkConnectClient = configuration.toSparkConnectClient

/**
* Add a single artifact to the client session.
Expand Down Expand Up @@ -463,7 +460,18 @@ object SparkConnectClient {
this
}

def build(): SparkConnectClient = new SparkConnectClient(_configuration)
/**
* Add an interceptor to be used during channel creation.
*
* Note that interceptors added last are executed first by gRPC.
*/
def interceptor(interceptor: ClientInterceptor): Builder = {
val interceptors = _configuration.interceptors ++ List(interceptor)
_configuration = _configuration.copy(interceptors = interceptors)
this
}

def build(): SparkConnectClient = _configuration.toSparkConnectClient
}

/**
Expand All @@ -478,7 +486,8 @@ object SparkConnectClient {
isSslEnabled: Option[Boolean] = None,
metadata: Map[String, String] = Map.empty,
userAgent: String = DEFAULT_USER_AGENT,
retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy()) {
retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy(),
interceptors: List[ClientInterceptor] = List.empty) {

def userContext: proto.UserContext = {
val builder = proto.UserContext.newBuilder()
Expand Down Expand Up @@ -509,12 +518,18 @@ object SparkConnectClient {

def createChannel(): ManagedChannel = {
val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, credentials)

if (metadata.nonEmpty) {
channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata))
}

interceptors.foreach(channelBuilder.intercept(_))

channelBuilder.maxInboundMessageSize(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE)
channelBuilder.build()
}

def toSparkConnectClient: SparkConnectClient = new SparkConnectClient(this, createChannel())
}

/**
Expand All @@ -540,10 +555,6 @@ object SparkConnectClient {
}
})
}

override def thisUsesUnstableApi(): Unit = {
// Marks this API is not stable. Left empty on purpose.
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.sql

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor}

import org.apache.spark.sql.connect.client.util.ConnectFunSuite

/**
Expand Down Expand Up @@ -73,4 +75,22 @@ class SparkSessionSuite extends ConnectFunSuite {
session2.close()
}
}

test("Custom Interceptor") {
val session = SparkSession
.builder()
.interceptor(new ClientInterceptor {
override def interceptCall[ReqT, RespT](
methodDescriptor: MethodDescriptor[ReqT, RespT],
callOptions: CallOptions,
channel: Channel): ClientCall[ReqT, RespT] = {
throw new RuntimeException("Blocked")
}
})
.create()

assertThrows[RuntimeException] {
session.range(10).count()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ object CheckConnectJvmClientCompatibility {
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession#Builder.create"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession#Builder.interceptor"),

// Steaming API
ProblemFilters.exclude[MissingTypesProblem](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.collection.mutable

import io.grpc.{Server, Status, StatusRuntimeException}
import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor, Server, Status, StatusRuntimeException}
import io.grpc.netty.NettyServerBuilder
import io.grpc.stub.StreamObserver
import org.scalatest.BeforeAndAfterEach
Expand Down Expand Up @@ -123,6 +123,28 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
assert(df.plan === service.getAndClearLatestInputPlan())
}

test("Custom Interceptor") {
startDummyServer(0)
client = SparkConnectClient
.builder()
.connectionString(s"sc://localhost:${server.getPort}")
.interceptor(new ClientInterceptor {
override def interceptCall[ReqT, RespT](
methodDescriptor: MethodDescriptor[ReqT, RespT],
callOptions: CallOptions,
channel: Channel): ClientCall[ReqT, RespT] = {
throw new RuntimeException("Blocked")
}
})
.build()

val session = SparkSession.builder().client(client).create()

assertThrows[RuntimeException] {
session.range(10).count()
}
}

private case class TestPackURI(
connectionString: String,
isCorrect: Boolean,
Expand Down

0 comments on commit eb07110

Please sign in to comment.