From a824a6de89fdd2ecc119a9bb48bca64da5db72bd Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 24 Aug 2023 00:53:33 -0700 Subject: [PATCH] [SPARK-44121][CONNECT][TESTS] Re-enable Arrow-based connect tests in Java 21 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to re-enable Arrow-based connect tests in Java 21. This depends on #42181. ### Why are the changes needed? To have Java 21 test coverage. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? ``` $ java -version openjdk version "21-ea" 2023-09-19 OpenJDK Runtime Environment (build 21-ea+32-2482) OpenJDK 64-Bit Server VM (build 21-ea+32-2482, mixed mode, sharing) $ build/sbt "connect/test" -Phive ... [info] Run completed in 14 seconds, 136 milliseconds. [info] Total number of tests run: 858 [info] Suites: completed 20, aborted 0 [info] Tests: succeeded 858, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. [success] Total time: 44 s, completed Aug 23, 2023, 9:42:53 PM $ build/sbt "connect-client-jvm/test" -Phive ... [info] Run completed in 1 minute, 24 seconds. [info] Total number of tests run: 1220 [info] Suites: completed 24, aborted 0 [info] Tests: succeeded 1220, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. [info] Passed: Total 1222, Failed 0, Errors 0, Passed 1222 ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42643 from dongjoon-hyun/SPARK-44121. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/sql/application/ReplE2ESuite.scala | 40 ++++----- .../client/util/RemoteSparkSession.scala | 86 ++++++++----------- .../planner/SparkConnectPlannerSuite.scala | 7 -- .../planner/SparkConnectProtoSuite.scala | 11 --- .../planner/SparkConnectServiceSuite.scala | 9 -- 5 files changed, 52 insertions(+), 101 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index b297123614750..5a909ab8b4178 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -23,7 +23,6 @@ import java.util.concurrent.{Executors, Semaphore, TimeUnit} import scala.util.Properties import org.apache.commons.io.output.ByteArrayOutputStream -import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession} @@ -51,29 +50,26 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { } override def beforeAll(): Unit = { - // TODO(SPARK-44121) Remove this check condition - if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) { - super.beforeAll() - ammoniteOut = new ByteArrayOutputStream() - testSuiteOut = new PipedOutputStream() - // Connect the `testSuiteOut` and `ammoniteIn` pipes - ammoniteIn = new PipedInputStream(testSuiteOut) - errorStream = new ByteArrayOutputStream() - - val args = Array("--port", serverPort.toString) - val task = new Runnable { - override def run(): Unit = { - ConnectRepl.doMain( - args = args, - semaphore = Some(semaphore), - inputStream = ammoniteIn, - outputStream = ammoniteOut, - errorStream = errorStream) - } + super.beforeAll() + ammoniteOut = 
new ByteArrayOutputStream() + testSuiteOut = new PipedOutputStream() + // Connect the `testSuiteOut` and `ammoniteIn` pipes + ammoniteIn = new PipedInputStream(testSuiteOut) + errorStream = new ByteArrayOutputStream() + + val args = Array("--port", serverPort.toString) + val task = new Runnable { + override def run(): Unit = { + ConnectRepl.doMain( + args = args, + semaphore = Some(semaphore), + inputStream = ammoniteIn, + outputStream = ammoniteOut, + errorStream = errorStream) } - - executorService.submit(task) } + + executorService.submit(task) } override def afterAll(): Unit = { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala index f14109e49b57d..33540bf498535 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala @@ -21,9 +21,7 @@ import java.util.concurrent.TimeUnit import scala.io.Source -import org.apache.commons.lang3.{JavaVersion, SystemUtils} -import org.scalactic.source.Position -import org.scalatest.{BeforeAndAfterAll, Tag} +import org.scalatest.BeforeAndAfterAll import sys.process._ import org.apache.spark.sql.SparkSession @@ -180,44 +178,41 @@ trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll { protected lazy val serverPort: Int = port override def beforeAll(): Unit = { - // TODO(SPARK-44121) Remove this check condition - if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) { - super.beforeAll() - SparkConnectServerUtils.start() - spark = SparkSession - .builder() - .client(SparkConnectClient.builder().port(serverPort).build()) - .create() - - // Retry and wait for the server to start - val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min - var 
sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff - var success = false - val error = new RuntimeException(s"Failed to start the test server on port $serverPort.") - - while (!success && System.nanoTime() < stop) { - try { - // Run a simple query to verify the server is really up and ready - val result = spark - .sql("select val from (values ('Hello'), ('World')) as t(val)") - .collect() - assert(result.length == 2) - success = true - debug("Spark Connect Server is up.") - } catch { - // ignored the error - case e: Throwable => - error.addSuppressed(e) - Thread.sleep(sleepInternalMs) - sleepInternalMs *= 2 - } + super.beforeAll() + SparkConnectServerUtils.start() + spark = SparkSession + .builder() + .client(SparkConnectClient.builder().port(serverPort).build()) + .create() + + // Retry and wait for the server to start + val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min + var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff + var success = false + val error = new RuntimeException(s"Failed to start the test server on port $serverPort.") + + while (!success && System.nanoTime() < stop) { + try { + // Run a simple query to verify the server is really up and ready + val result = spark + .sql("select val from (values ('Hello'), ('World')) as t(val)") + .collect() + assert(result.length == 2) + success = true + debug("Spark Connect Server is up.") + } catch { + // ignored the error + case e: Throwable => + error.addSuppressed(e) + Thread.sleep(sleepInternalMs) + sleepInternalMs *= 2 } + } - // Throw error if failed - if (!success) { - debug(error) - throw error - } + // Throw error if failed + if (!success) { + debug(error) + throw error } } @@ -230,17 +225,4 @@ trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll { spark = null super.afterAll() } - - /** - * SPARK-44259: override test function to skip `RemoteSparkSession-based` tests as default, we - * should delete this function after 
SPARK-44121 is completed. - */ - override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { - super.test(testName, testTags: _*) { - // TODO(SPARK-44121) Re-enable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) - testFun - } - } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 39b4f40215dc3..0caa02a0b6112 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -21,7 +21,6 @@ import scala.collection.JavaConverters._ import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver -import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto @@ -479,8 +478,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("transform LocalRelation") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val rows = (0 until 10).map { i => InternalRow(i, UTF8String.fromString(s"str-$i"), InternalRow(i)) } @@ -582,8 +579,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("transform UnresolvedStar and ExpressionString") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val sql = "SELECT * FROM VALUES (1,'spark',1), (2,'hadoop',2), (3,'kafka',3) AS tab(id, name, value)" val input = proto.Relation @@ -620,8 +615,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { 
} test("transform UnresolvedStar with target field") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val rows = (0 until 10).map { i => InternalRow(InternalRow(InternalRow(i, i + 1))) } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 63b6f775d7b14..0c12bf5e625a9 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -21,7 +21,6 @@ import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ import com.google.protobuf.ByteString -import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.apache.spark.{SparkClassNotFoundException, SparkIllegalArgumentException} import org.apache.spark.connect.proto @@ -695,8 +694,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } test("WriteTo with create") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("testcat.table_name") { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) @@ -724,8 +721,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } test("WriteTo with create and using") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val defaultOwnership = Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName()) withTable("testcat.table_name") { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) @@ -763,8 +758,6 @@ class SparkConnectProtoSuite extends PlanTest with 
SparkConnectPlanTest { } test("WriteTo with append") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("testcat.table_name") { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) @@ -796,8 +789,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } test("WriteTo with overwrite") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("testcat.table_name") { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) @@ -851,8 +842,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } test("WriteTo with overwritePartitions") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("testcat.table_name") { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 74649e15e9eba..90c9d13def616 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -29,7 +29,6 @@ import io.grpc.stub.StreamObserver import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{BigIntVector, Float8Vector} import org.apache.arrow.vector.ipc.ArrowStreamReader -import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.mockito.Mockito.when import org.scalatest.Tag import org.scalatestplus.mockito.MockitoSugar @@ -157,8 +156,6 @@ class SparkConnectServiceSuite 
test("SPARK-41224: collect data using arrow") { withEvents { verifyEvents => - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val instance = new SparkConnectService(false) val connect = new MockRemoteSession() val context = proto.UserContext @@ -246,8 +243,6 @@ class SparkConnectServiceSuite test("SPARK-44776: LocalTableScanExec") { withEvents { verifyEvents => - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val instance = new SparkConnectService(false) val connect = new MockRemoteSession() val context = proto.UserContext @@ -319,8 +314,6 @@ class SparkConnectServiceSuite // Set 10 KiB as the batch size limit val batchSize = 10 * 1024 withSparkConf("spark.connect.grpc.arrow.maxBatchSize" -> batchSize.toString) { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val instance = new SparkConnectService(false) val connect = new MockRemoteSession() val context = proto.UserContext @@ -742,8 +735,6 @@ class SparkConnectServiceSuite } test("Test observe response") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("test") { spark.sql(""" | CREATE TABLE test (col1 INT, col2 STRING)