From 6755f1c908d76e8cd14834e09c5c4395a15e9999 Mon Sep 17 00:00:00 2001 From: Alice Sayutina Date: Tue, 11 Jul 2023 17:33:57 +0900 Subject: [PATCH] [SPARK-44263][CONNECT] Custom Interceptors Support ### What changes were proposed in this pull request? Extend SparkSession to allow custom interceptors (gRPC's ClientInterceptor). ### Why are the changes needed? This can be used for different purposes, such as customized authentication mechanisms. This is similar to Python, which allows custom channel builders (https://github.com/apache/spark/pull/40993). ### Does this PR introduce _any_ user-facing change? SparkSession allows custom interceptors. The proposed usage is ``` val interceptor = new ClientInterceptor {....} val session = SparkSession.builder().interceptor(interceptor).create() ``` The same works with more than one interceptor by calling `interceptor` repeatedly on the builder. ### How was this patch tested? Unit tests. Closes #41880 from cdkrot/scala_channel_builder_2. Lead-authored-by: Alice Sayutina Co-authored-by: Alice Sayutina Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/sql/SparkSession.scala | 15 ++++++++- .../connect/client/SparkConnectClient.scala | 31 +++++++++++++------ .../apache/spark/sql/SparkSessionSuite.scala | 20 ++++++++++++ .../CheckConnectJvmClientCompatibility.scala | 2 ++ .../client/SparkConnectClientSuite.scala | 24 +++++++++++++- 5 files changed, 80 insertions(+), 12 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 54e9102c55ce5..529ba97c40dff 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import com.google.common.cache.{CacheBuilder, CacheLoader} +import io.grpc.ClientInterceptor import 
org.apache.arrow.memory.RootAllocator import org.apache.spark.annotation.{DeveloperApi, Experimental} @@ -629,7 +630,7 @@ object SparkSession extends Logging { * Create a new [[SparkSession]] based on the connect client [[Configuration]]. */ private[sql] def create(configuration: Configuration): SparkSession = { - new SparkSession(new SparkConnectClient(configuration), cleaner, planIdGenerator) + new SparkSession(configuration.toSparkConnectClient, cleaner, planIdGenerator) } /** @@ -656,6 +657,18 @@ object SparkSession extends Logging { this } + /** + * Add an interceptor [[ClientInterceptor]] to be used during channel creation. + * + * Note that interceptors added last are executed first by gRPC. + * + * @since 3.5.0 + */ + def interceptor(interceptor: ClientInterceptor): Builder = { + builder.interceptor(interceptor) + this + } + private[sql] def client(client: SparkConnectClient): Builder = { this.client = client this diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 3fb06dab89aae..b41ae5555bf3a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -35,9 +35,6 @@ private[sql] class SparkConnectClient( private[sql] val configuration: SparkConnectClient.Configuration, private val channel: ManagedChannel) { - def this(configuration: SparkConnectClient.Configuration) = - this(configuration, configuration.createChannel()) - private val userContext: UserContext = configuration.userContext private[this] val bstub = new CustomSparkConnectBlockingStub(channel, configuration.retryPolicy) @@ -198,7 +195,7 @@ private[sql] class SparkConnectClient( bstub.interrupt(request) } - def copy(): SparkConnectClient = new 
SparkConnectClient(configuration) + def copy(): SparkConnectClient = configuration.toSparkConnectClient /** * Add a single artifact to the client session. @@ -463,7 +460,18 @@ object SparkConnectClient { this } - def build(): SparkConnectClient = new SparkConnectClient(_configuration) + /** + * Add an interceptor to be used during channel creation. + * + * Note that interceptors added last are executed first by gRPC. + */ + def interceptor(interceptor: ClientInterceptor): Builder = { + val interceptors = _configuration.interceptors ++ List(interceptor) + _configuration = _configuration.copy(interceptors = interceptors) + this + } + + def build(): SparkConnectClient = _configuration.toSparkConnectClient } /** @@ -478,7 +486,8 @@ object SparkConnectClient { isSslEnabled: Option[Boolean] = None, metadata: Map[String, String] = Map.empty, userAgent: String = DEFAULT_USER_AGENT, - retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy()) { + retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy(), + interceptors: List[ClientInterceptor] = List.empty) { def userContext: proto.UserContext = { val builder = proto.UserContext.newBuilder() @@ -509,12 +518,18 @@ object SparkConnectClient { def createChannel(): ManagedChannel = { val channelBuilder = Grpc.newChannelBuilderForAddress(host, port, credentials) + if (metadata.nonEmpty) { channelBuilder.intercept(new MetadataHeaderClientInterceptor(metadata)) } + + interceptors.foreach(channelBuilder.intercept(_)) + channelBuilder.maxInboundMessageSize(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE) channelBuilder.build() } + + def toSparkConnectClient: SparkConnectClient = new SparkConnectClient(this, createChannel()) } /** @@ -540,10 +555,6 @@ object SparkConnectClient { } }) } - - override def thisUsesUnstableApi(): Unit = { - // Marks this API is not stable. Left empty on purpose. 
- } } /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 48d55311da0e8..97fb46bf48af4 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql +import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} + import org.apache.spark.sql.connect.client.util.ConnectFunSuite /** @@ -73,4 +75,22 @@ class SparkSessionSuite extends ConnectFunSuite { session2.close() } } + + test("Custom Interceptor") { + val session = SparkSession + .builder() + .interceptor(new ClientInterceptor { + override def interceptCall[ReqT, RespT]( + methodDescriptor: MethodDescriptor[ReqT, RespT], + callOptions: CallOptions, + channel: Channel): ClientCall[ReqT, RespT] = { + throw new RuntimeException("Blocked") + } + }) + .create() + + assertThrows[RuntimeException] { + session.range(10).count() + } + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 576dadd3d9f61..dded96a0b13c9 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -384,6 +384,8 @@ object CheckConnectJvmClientCompatibility { ), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.create"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession#Builder.interceptor"), // 
Steaming API ProblemFilters.exclude[MissingTypesProblem]( diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 54eea832b753c..495b4b20b0863 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable -import io.grpc.{Server, Status, StatusRuntimeException} +import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor, Server, Status, StatusRuntimeException} import io.grpc.netty.NettyServerBuilder import io.grpc.stub.StreamObserver import org.scalatest.BeforeAndAfterEach @@ -123,6 +123,28 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { assert(df.plan === service.getAndClearLatestInputPlan()) } + test("Custom Interceptor") { + startDummyServer(0) + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .interceptor(new ClientInterceptor { + override def interceptCall[ReqT, RespT]( + methodDescriptor: MethodDescriptor[ReqT, RespT], + callOptions: CallOptions, + channel: Channel): ClientCall[ReqT, RespT] = { + throw new RuntimeException("Blocked") + } + }) + .build() + + val session = SparkSession.builder().client(client).create() + + assertThrows[RuntimeException] { + session.range(10).count() + } + } + private case class TestPackURI( connectionString: String, isCorrect: Boolean,