Skip to content

Commit

Permalink
[SPARK-44121][CONNECT][TESTS] Re-enable Arrow-based connect tests in Ja…
Browse files Browse the repository at this point in the history
…va 21

### What changes were proposed in this pull request?

This PR aims to re-enable Arrow-based connect tests in Java 21.
This depends on apache#42181.

### Why are the changes needed?

To have Java 21 test coverage.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

```
$ java -version
openjdk version "21-ea" 2023-09-19
OpenJDK Runtime Environment (build 21-ea+32-2482)
OpenJDK 64-Bit Server VM (build 21-ea+32-2482, mixed mode, sharing)

$ build/sbt "connect/test" -Phive
...
[info] Run completed in 14 seconds, 136 milliseconds.
[info] Total number of tests run: 858
[info] Suites: completed 20, aborted 0
[info] Tests: succeeded 858, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
[success] Total time: 44 s, completed Aug 23, 2023, 9:42:53 PM

$ build/sbt "connect-client-jvm/test" -Phive
...
[info] Run completed in 1 minute, 24 seconds.
[info] Total number of tests run: 1220
[info] Suites: completed 24, aborted 0
[info] Tests: succeeded 1220, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
[info] Passed: Total 1222, Failed 0, Errors 0, Passed 1222
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#42643 from dongjoon-hyun/SPARK-44121.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
dongjoon-hyun committed Aug 24, 2023
1 parent 897b87f commit a824a6d
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import java.util.concurrent.{Executors, Semaphore, TimeUnit}
import scala.util.Properties

import org.apache.commons.io.output.ByteArrayOutputStream
import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
Expand Down Expand Up @@ -51,29 +50,26 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
}

override def beforeAll(): Unit = {
// TODO(SPARK-44121) Remove this check condition
if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) {
super.beforeAll()
ammoniteOut = new ByteArrayOutputStream()
testSuiteOut = new PipedOutputStream()
// Connect the `testSuiteOut` and `ammoniteIn` pipes
ammoniteIn = new PipedInputStream(testSuiteOut)
errorStream = new ByteArrayOutputStream()

val args = Array("--port", serverPort.toString)
val task = new Runnable {
override def run(): Unit = {
ConnectRepl.doMain(
args = args,
semaphore = Some(semaphore),
inputStream = ammoniteIn,
outputStream = ammoniteOut,
errorStream = errorStream)
}
super.beforeAll()
ammoniteOut = new ByteArrayOutputStream()
testSuiteOut = new PipedOutputStream()
// Connect the `testSuiteOut` and `ammoniteIn` pipes
ammoniteIn = new PipedInputStream(testSuiteOut)
errorStream = new ByteArrayOutputStream()

val args = Array("--port", serverPort.toString)
val task = new Runnable {
override def run(): Unit = {
ConnectRepl.doMain(
args = args,
semaphore = Some(semaphore),
inputStream = ammoniteIn,
outputStream = ammoniteOut,
errorStream = errorStream)
}

executorService.submit(task)
}

executorService.submit(task)
}

override def afterAll(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ import java.util.concurrent.TimeUnit

import scala.io.Source

import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.scalactic.source.Position
import org.scalatest.{BeforeAndAfterAll, Tag}
import org.scalatest.BeforeAndAfterAll
import sys.process._

import org.apache.spark.sql.SparkSession
Expand Down Expand Up @@ -180,44 +178,41 @@ trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll {
protected lazy val serverPort: Int = port

override def beforeAll(): Unit = {
// TODO(SPARK-44121) Remove this check condition
if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) {
super.beforeAll()
SparkConnectServerUtils.start()
spark = SparkSession
.builder()
.client(SparkConnectClient.builder().port(serverPort).build())
.create()

// Retry and wait for the server to start
val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min
var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff
var success = false
val error = new RuntimeException(s"Failed to start the test server on port $serverPort.")

while (!success && System.nanoTime() < stop) {
try {
// Run a simple query to verify the server is really up and ready
val result = spark
.sql("select val from (values ('Hello'), ('World')) as t(val)")
.collect()
assert(result.length == 2)
success = true
debug("Spark Connect Server is up.")
} catch {
// ignored the error
case e: Throwable =>
error.addSuppressed(e)
Thread.sleep(sleepInternalMs)
sleepInternalMs *= 2
}
super.beforeAll()
SparkConnectServerUtils.start()
spark = SparkSession
.builder()
.client(SparkConnectClient.builder().port(serverPort).build())
.create()

// Retry and wait for the server to start
val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min
var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff
var success = false
val error = new RuntimeException(s"Failed to start the test server on port $serverPort.")

while (!success && System.nanoTime() < stop) {
try {
// Run a simple query to verify the server is really up and ready
val result = spark
.sql("select val from (values ('Hello'), ('World')) as t(val)")
.collect()
assert(result.length == 2)
success = true
debug("Spark Connect Server is up.")
} catch {
// ignored the error
case e: Throwable =>
error.addSuppressed(e)
Thread.sleep(sleepInternalMs)
sleepInternalMs *= 2
}
}

// Throw error if failed
if (!success) {
debug(error)
throw error
}
// Throw error if failed
if (!success) {
debug(error)
throw error
}
}

Expand All @@ -230,17 +225,4 @@ trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll {
spark = null
super.afterAll()
}

/**
* SPARK-44259: override test function to skip `RemoteSparkSession-based` tests as default, we
* should delete this function after SPARK-44121 is completed.
*/
override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
pos: Position): Unit = {
super.test(testName, testTags: _*) {
// TODO(SPARK-44121) Re-enable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
testFun
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import scala.collection.JavaConverters._

import com.google.protobuf.ByteString
import io.grpc.stub.StreamObserver
import org.apache.commons.lang3.{JavaVersion, SystemUtils}

import org.apache.spark.SparkFunSuite
import org.apache.spark.connect.proto
Expand Down Expand Up @@ -479,8 +478,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
}

test("transform LocalRelation") {
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
val rows = (0 until 10).map { i =>
InternalRow(i, UTF8String.fromString(s"str-$i"), InternalRow(i))
}
Expand Down Expand Up @@ -582,8 +579,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
}

test("transform UnresolvedStar and ExpressionString") {
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
val sql =
"SELECT * FROM VALUES (1,'spark',1), (2,'hadoop',2), (3,'kafka',3) AS tab(id, name, value)"
val input = proto.Relation
Expand Down Expand Up @@ -620,8 +615,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
}

test("transform UnresolvedStar with target field") {
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
val rows = (0 until 10).map { i =>
InternalRow(InternalRow(InternalRow(i, i + 1)))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import java.nio.file.{Files, Paths}
import scala.collection.JavaConverters._

import com.google.protobuf.ByteString
import org.apache.commons.lang3.{JavaVersion, SystemUtils}

import org.apache.spark.{SparkClassNotFoundException, SparkIllegalArgumentException}
import org.apache.spark.connect.proto
Expand Down Expand Up @@ -695,8 +694,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
}

test("WriteTo with create") {
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
withTable("testcat.table_name") {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)

Expand Down Expand Up @@ -724,8 +721,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
}

test("WriteTo with create and using") {
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
val defaultOwnership = Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName())
withTable("testcat.table_name") {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
Expand Down Expand Up @@ -763,8 +758,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
}

test("WriteTo with append") {
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
withTable("testcat.table_name") {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)

Expand Down Expand Up @@ -796,8 +789,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
}

test("WriteTo with overwrite") {
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
withTable("testcat.table_name") {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)

Expand Down Expand Up @@ -851,8 +842,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
}

test("WriteTo with overwritePartitions") {
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
withTable("testcat.table_name") {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import io.grpc.stub.StreamObserver
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{BigIntVector, Float8Vector}
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.mockito.Mockito.when
import org.scalatest.Tag
import org.scalatestplus.mockito.MockitoSugar
Expand Down Expand Up @@ -157,8 +156,6 @@ class SparkConnectServiceSuite

test("SPARK-41224: collect data using arrow") {
withEvents { verifyEvents =>
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
val instance = new SparkConnectService(false)
val connect = new MockRemoteSession()
val context = proto.UserContext
Expand Down Expand Up @@ -246,8 +243,6 @@ class SparkConnectServiceSuite

test("SPARK-44776: LocalTableScanExec") {
withEvents { verifyEvents =>
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
val instance = new SparkConnectService(false)
val connect = new MockRemoteSession()
val context = proto.UserContext
Expand Down Expand Up @@ -319,8 +314,6 @@ class SparkConnectServiceSuite
// Set 10 KiB as the batch size limit
val batchSize = 10 * 1024
withSparkConf("spark.connect.grpc.arrow.maxBatchSize" -> batchSize.toString) {
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
val instance = new SparkConnectService(false)
val connect = new MockRemoteSession()
val context = proto.UserContext
Expand Down Expand Up @@ -742,8 +735,6 @@ class SparkConnectServiceSuite
}

test("Test observe response") {
// TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
withTable("test") {
spark.sql("""
| CREATE TABLE test (col1 INT, col2 STRING)
Expand Down

0 comments on commit a824a6d

Please sign in to comment.