From 30ec6e358536dfb695fcc1b8c3f084acb576d871 Mon Sep 17 00:00:00 2001
From: yangjie01
Date: Wed, 1 Nov 2023 21:08:04 -0700
Subject: [PATCH] [SPARK-45742][CORE][CONNECT][MLLIB][PYTHON] Introduce an implicit function for Scala Array to wrap into `immutable.ArraySeq`

### What changes were proposed in this pull request?

Currently, we need to use `immutable.ArraySeq.unsafeWrapArray(array)` to wrap an Array into an `immutable.ArraySeq`, which makes the code look bloated. So this PR introduces an implicit function, `toImmutableArraySeq`, to make it easier to wrap a Scala Array into an `immutable.ArraySeq`.

After this PR, an Array can be wrapped into an `immutable.ArraySeq` as follows:

```scala
import org.apache.spark.util.ArrayImplicits._

val dataArray = ...
val immutableArraySeq = dataArray.toImmutableArraySeq
```

At the same time, this PR replaces the existing uses of `immutable.ArraySeq.unsafeWrapArray(array)` with the new method.

In addition, this implicit function will help move the work in SPARK-45686 and SPARK-45687 forward.

### Why are the changes needed?

It makes the code for wrapping a Scala Array into an `immutable.ArraySeq` less bloated.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Pass GitHub Actions

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43607 from LuciferYang/SPARK-45742.

Authored-by: yangjie01
Signed-off-by: Dongjoon Hyun
---
 .../apache/spark/util/ArrayImplicits.scala    | 36 +++++++++++++
 .../org/apache/spark/sql/SparkSession.scala   |  4 +-
 .../client/GrpcExceptionConverter.scala       |  4 +-
 .../connect/planner/SparkConnectPlanner.scala | 27 +++++-----
 .../spark/sql/connect/utils/ErrorUtils.scala  | 32 ++++++------
 .../spark/util/ArrayImplicitsSuite.scala      | 50 +++++++++++++++++++
 .../python/GaussianMixtureModelWrapper.scala  |  4 +-
 .../mllib/api/python/LDAModelWrapper.scala    |  8 +--
 8 files changed, 126 insertions(+), 39 deletions(-)
 create mode 100644 common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala
 create mode 100644 core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala

diff --git a/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala
new file mode 100644
index 0000000000000..08997a800c957
--- /dev/null
+++ b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.immutable
+
+/**
+ * Implicit methods related to Scala Array.
+ */ +private[spark] object ArrayImplicits { + + implicit class SparkArrayOps[T](xs: Array[T]) { + + /** + * Wraps an Array[T] as an immutable.ArraySeq[T] without copying. + */ + def toImmutableArraySeq: immutable.ArraySeq[T] = + if (xs eq null) null + else immutable.ArraySeq.unsafeWrapArray(xs) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1cc1c8400fa89..34756f9a440bb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,7 +21,6 @@ import java.net.URI import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicLong, AtomicReference} -import scala.collection.immutable import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag @@ -45,6 +44,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * The entry point to programming Spark with the Dataset and DataFrame API. @@ -248,7 +248,7 @@ class SparkSession private[sql] ( proto.SqlCommand .newBuilder() .setSql(sqlText) - .addAllPosArguments(immutable.ArraySeq.unsafeWrapArray(args.map(lit(_).expr)).asJava))) + .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) // .toBuffer forces that the iterator is consumed and closed val responseSeq = client.execute(plan.build()).toBuffer.toSeq diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 3e53722caeb07..652797bc2e40f 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.client import java.time.DateTimeException -import scala.collection.immutable import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag @@ -37,6 +36,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.streaming.StreamingQueryException +import org.apache.spark.util.ArrayImplicits._ /** * GrpcExceptionConverter handles the conversion of StatusRuntimeExceptions into Spark exceptions. 
@@ -375,7 +375,7 @@ private[client] object GrpcExceptionConverter { FetchErrorDetailsResponse.Error .newBuilder() .setMessage(message) - .addAllErrorTypeHierarchy(immutable.ArraySeq.unsafeWrapArray(classes).asJava) + .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava) .build())) } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index ec57909ad144e..018e293795e9d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connect.planner -import scala.collection.immutable import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.util.Try @@ -80,6 +79,7 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.CacheId +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils final case class InvalidCommandInput( @@ -3184,9 +3184,9 @@ class SparkConnectPlanner( case StreamingQueryManagerCommand.CommandCase.ACTIVE => val active_queries = session.streams.active respBuilder.getActiveBuilder.addAllActiveQueries( - immutable.ArraySeq - .unsafeWrapArray(active_queries - .map(query => buildStreamingQueryInstance(query))) + active_queries + .map(query => buildStreamingQueryInstance(query)) + .toImmutableArraySeq .asJava) case StreamingQueryManagerCommand.CommandCase.GET_QUERY => @@ -3265,15 +3265,16 @@ class SparkConnectPlanner( .setGetResourcesCommandResult( proto.GetResourcesCommandResult .newBuilder() - .putAllResources(session.sparkContext.resources.view - .mapValues(resource => - proto.ResourceInformation - .newBuilder() - .setName(resource.name) - .addAllAddresses(immutable.ArraySeq.unsafeWrapArray(resource.addresses).asJava) - .build()) - .toMap - .asJava) + .putAllResources( + session.sparkContext.resources.view + .mapValues(resource => + proto.ResourceInformation + .newBuilder() + .setName(resource.name) + .addAllAddresses(resource.addresses.toImmutableArraySeq.asJava) + .build()) + .toMap + .asJava) .build()) .build()) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 837ee5a00227c..744fa3c8aa1a4 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.connect.utils import java.util.UUID import scala.annotation.tailrec -import scala.collection.immutable import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ @@ -43,6 +42,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SessionKey, SparkConnectService} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ArrayImplicits._ private[connect] object ErrorUtils extends Logging { @@ -91,21 +91,21 @@ 
private[connect] object ErrorUtils extends Logging { if (serverStackTraceEnabled) { builder.addAllStackTrace( - immutable.ArraySeq - .unsafeWrapArray(currentError.getStackTrace - .map { stackTraceElement => - val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement - .newBuilder() - .setDeclaringClass(stackTraceElement.getClassName) - .setMethodName(stackTraceElement.getMethodName) - .setLineNumber(stackTraceElement.getLineNumber) - - if (stackTraceElement.getFileName != null) { - stackTraceBuilder.setFileName(stackTraceElement.getFileName) - } - - stackTraceBuilder.build() - }) + currentError.getStackTrace + .map { stackTraceElement => + val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement + .newBuilder() + .setDeclaringClass(stackTraceElement.getClassName) + .setMethodName(stackTraceElement.getMethodName) + .setLineNumber(stackTraceElement.getLineNumber) + + if (stackTraceElement.getFileName != null) { + stackTraceBuilder.setFileName(stackTraceElement.getFileName) + } + + stackTraceBuilder.build() + } + .toImmutableArraySeq .asJava) } diff --git a/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala new file mode 100644 index 0000000000000..135af550c4b39 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import scala.collection.immutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ArrayImplicits._ + +class ArrayImplicitsSuite extends SparkFunSuite { + + test("Int Array") { + val data = Array(1, 2, 3) + val arraySeq = data.toImmutableArraySeq + assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofInt]) + assert(arraySeq.length === 3) + assert(arraySeq.unsafeArray.sameElements(data)) + } + + test("TestClass Array") { + val data = Array(TestClass(1), TestClass(2), TestClass(3)) + val arraySeq = data.toImmutableArraySeq + assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofRef[TestClass]]) + assert(arraySeq.length === 3) + assert(arraySeq.unsafeArray.sameElements(data)) + } + + test("Null Array") { + val data: Array[Int] = null + val arraySeq = data.toImmutableArraySeq + assert(arraySeq == null) + } + + case class TestClass(i: Int) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index 1eed97a8d4f65..2f3f396730be2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.api.python -import scala.collection.immutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkContext import org.apache.spark.mllib.clustering.GaussianMixtureModel import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.util.ArrayImplicits._ /** * Wrapper around GaussianMixtureModel to provide helper methods in Python @@ -38,7 +38,7 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { val modelGaussians = model.gaussians.map { gaussian => Array[Any](gaussian.mu, gaussian.sigma) } - SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(modelGaussians).asJava) + SerDe.dumps(modelGaussians.toImmutableArraySeq.asJava) } def predictSoft(point: Vector): Vector = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala index b919b0a8c3f2e..6a6c6cf6bcfb3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala @@ -16,12 +16,12 @@ */ package org.apache.spark.mllib.api.python -import scala.collection.immutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkContext import org.apache.spark.mllib.clustering.LDAModel import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.util.ArrayImplicits._ /** * Wrapper around LDAModel to provide helper methods in Python @@ -36,11 +36,11 @@ private[python] class LDAModelWrapper(model: LDAModel) { def describeTopics(maxTermsPerTopic: Int): Array[Byte] = { val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) => - val jTerms = immutable.ArraySeq.unsafeWrapArray(terms).asJava - val jTermWeights = immutable.ArraySeq.unsafeWrapArray(termWeights).asJava + val jTerms = terms.toImmutableArraySeq.asJava + val jTermWeights = termWeights.toImmutableArraySeq.asJava Array[Any](jTerms, jTermWeights) } - SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(topics).asJava) + SerDe.dumps(topics.toImmutableArraySeq.asJava) } def save(sc: SparkContext, path: 
String): Unit = model.save(sc, path)
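
A minimal usage sketch of the new helper: because `ArrayImplicits` is `private[spark]`, the snippet below assumes it is compiled under the `org.apache.spark` namespace; the package, object, and variable names are illustrative only. The null behaviour mirrors the "Null Array" test in `ArrayImplicitsSuite`.

```scala
// Hypothetical example file; ArrayImplicits is private[spark], so this sketch
// assumes it lives somewhere under the org.apache.spark package.
package org.apache.spark.examples

import scala.collection.immutable

import org.apache.spark.util.ArrayImplicits._

object ToImmutableArraySeqSketch {
  def main(args: Array[String]): Unit = {
    // Wraps the array without copying; the resulting ArraySeq is backed by `data` itself.
    val data = Array(1, 2, 3)
    val seq: immutable.ArraySeq[Int] = data.toImmutableArraySeq
    assert(seq.length == 3)

    // A null array wraps to null, matching the "Null Array" test in ArrayImplicitsSuite.
    val nullArray: Array[String] = null
    assert(nullArray.toImmutableArraySeq == null)
  }
}
```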
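Most call sites touched by this diff follow the same pattern: wrap an Array and then convert it to a `java.util.List` for a Java builder API. A before/after sketch of that pattern, with made-up helper names for illustration:

```scala
// Hypothetical example file under the org.apache.spark namespace (ArrayImplicits is private[spark]).
package org.apache.spark.examples

import java.util.{List => JList}

import scala.collection.immutable
import scala.jdk.CollectionConverters._

import org.apache.spark.util.ArrayImplicits._

object JavaInteropSketch {
  // Before: explicit wrapping, as previously written at the call sites.
  def toJavaListBefore[T](arr: Array[T]): JList[T] =
    immutable.ArraySeq.unsafeWrapArray(arr).asJava

  // After: same behaviour, expressed with the new implicit.
  def toJavaListAfter[T](arr: Array[T]): JList[T] =
    arr.toImmutableArraySeq.asJava

  def main(args: Array[String]): Unit = {
    val names = Array("a", "b", "c")
    assert(toJavaListBefore(names).size() == 3)
    assert(toJavaListAfter(names).size() == 3)
  }
}
```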