From 975d8d5ae08b1b0ebc7a4947f0ba8286f06b0c1c Mon Sep 17 00:00:00 2001
From: beliefer
Date: Thu, 1 Feb 2024 13:44:17 +0800
Subject: [PATCH] [SPARK-46937][SQL] Improve concurrency performance for FunctionRegistry

---
Review notes (text in this section is not part of the commit message):

* functionBuilders changes from a mutable.HashMap accessed inside
  `synchronized` blocks to a java.util.concurrent.ConcurrentHashMap, and the
  `synchronized` wrappers are removed from the register/lookup/list/drop/clear
  methods.
* ConcurrentHashMap.get/put/remove return null, not None, when the key has no
  mapping. The lookup and drop paths wrap the result in Option(...).
  internalRegisterFunction must also skip a null result of put, otherwise the
  "replaced a previously registered function" warning is logged for every
  first-time registration; the guard below adds that null check.
* NOTE(review): the post-image blob id on the `index` line was not regenerated
  after the null-check change.

 .../catalyst/analysis/FunctionRegistry.scala | 54 +++++++++----------
 1 file changed, 26 insertions(+), 28 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f37f47c13ed45..bc787be107e5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -18,9 +18,9 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import java.util.Locale
-import javax.annotation.concurrent.GuardedBy
+import java.util.concurrent.ConcurrentHashMap
 
-import scala.collection.mutable
+import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 
 import org.apache.spark.SparkUnsupportedOperationException
@@ -195,9 +195,8 @@ object FunctionRegistryBase {
 
 trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging {
 
-  @GuardedBy("this")
   protected val functionBuilders =
-    new mutable.HashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)]
+    new ConcurrentHashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)]
 
   // Resolution of the function name is always case insensitive, but the database name
   // depends on the caller
@@ -220,10 +219,10 @@ trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging
   def internalRegisterFunction(
       name: FunctionIdentifier,
       info: ExpressionInfo,
-      builder: FunctionBuilder): Unit = synchronized {
+      builder: FunctionBuilder): Unit = {
     val newFunction = (info, builder)
     functionBuilders.put(name, newFunction) match {
-      case Some(previousFunction) if previousFunction != newFunction =>
+      case previousFunction if previousFunction != null && previousFunction != newFunction =>
         logWarning(log"The function ${MDC(FUNCTION_NAME, name)} replaced a " +
           log"previously registered function.")
       case _ =>
@@ -231,34 +230,25 @@ trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging
   }
 
   override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): T = {
-    val func = synchronized {
-      functionBuilders.get(normalizeFuncName(name)).map(_._2).getOrElse {
-        throw QueryCompilationErrors.unresolvedRoutineError(name, Seq("system.builtin"))
-      }
+    val func = Option(functionBuilders.get(normalizeFuncName(name))).map(_._2).getOrElse {
+      throw QueryCompilationErrors.unresolvedRoutineError(name, Seq("system.builtin"))
     }
     func(children)
   }
 
-  override def listFunction(): Seq[FunctionIdentifier] = synchronized {
-    functionBuilders.iterator.map(_._1).toList
-  }
+  override def listFunction(): Seq[FunctionIdentifier] =
+    functionBuilders.keys().asScala.toSeq
 
-  override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = synchronized {
-    functionBuilders.get(normalizeFuncName(name)).map(_._1)
-  }
+  override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] =
+    Option(functionBuilders.get(normalizeFuncName(name))).map(_._1)
 
-  override def lookupFunctionBuilder(
-      name: FunctionIdentifier): Option[FunctionBuilder] = synchronized {
-    functionBuilders.get(normalizeFuncName(name)).map(_._2)
-  }
+  override def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder] =
+    Option(functionBuilders.get(normalizeFuncName(name))).map(_._2)
 
-  override def dropFunction(name: FunctionIdentifier): Boolean = synchronized {
-    functionBuilders.remove(normalizeFuncName(name)).isDefined
-  }
+  override def dropFunction(name: FunctionIdentifier): Boolean =
+    Option(functionBuilders.remove(normalizeFuncName(name))).isDefined
 
-  override def clear(): Unit = synchronized {
-    functionBuilders.clear()
-  }
+  override def clear(): Unit = functionBuilders.clear()
 }
 
 /**
@@ -308,7 +298,11 @@ class SimpleFunctionRegistry
 
   override def clone(): SimpleFunctionRegistry = synchronized {
     val registry = new SimpleFunctionRegistry
-    functionBuilders.iterator.foreach { case (name, (info, builder)) =>
+    val iterator = functionBuilders.entrySet().iterator()
+    while (iterator.hasNext) {
+      val entry = iterator.next()
+      val name = entry.getKey
+      val (info, builder) = entry.getValue
       registry.internalRegisterFunction(name, info, builder)
     }
     registry
@@ -1032,7 +1026,11 @@ class SimpleTableFunctionRegistry extends SimpleFunctionRegistryBase[LogicalPlan
 
   override def clone(): SimpleTableFunctionRegistry = synchronized {
     val registry = new SimpleTableFunctionRegistry
-    functionBuilders.iterator.foreach { case (name, (info, builder)) =>
+    val iterator = functionBuilders.entrySet().iterator()
+    while (iterator.hasNext) {
+      val entry = iterator.next()
+      val name = entry.getKey
+      val (info, builder) = entry.getValue
       registry.internalRegisterFunction(name, info, builder)
     }
     registry