Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-46937][SQL] Improve concurrency performance for FunctionRegistry #44976

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst.analysis

import java.util.Locale
import javax.annotation.concurrent.GuardedBy
import java.util.concurrent.ConcurrentHashMap

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

import org.apache.spark.SparkUnsupportedOperationException
Expand Down Expand Up @@ -195,9 +195,8 @@ object FunctionRegistryBase {

trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging {

@GuardedBy("this")
protected val functionBuilders =
new mutable.HashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)]
new ConcurrentHashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)]

// Resolution of the function name is always case insensitive, but the database name
// depends on the caller
Expand All @@ -220,45 +219,36 @@ trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging
/**
 * Registers `builder` together with its `info` under `name`, overwriting any entry that is
 * already stored for that identifier.
 *
 * A warning is logged only when an entry was previously present for `name` and it differs
 * from the one being registered.
 *
 * @param name identifier the function is stored under (used as-is as the map key)
 * @param info expression metadata stored alongside the builder
 * @param builder function builder to store
 */
def internalRegisterFunction(
    name: FunctionIdentifier,
    info: ExpressionInfo,
    builder: FunctionBuilder): Unit = {
  val newFunction = (info, builder)
  // ConcurrentHashMap.put returns null (not None) when the key had no previous mapping.
  // Wrapping the result in Option keeps a first-time registration from being reported as
  // a replacement, which a bare `previous != newFunction` guard would do for null.
  Option(functionBuilders.put(name, newFunction)) match {
    case Some(previousFunction) if previousFunction != newFunction =>
      logWarning(log"The function ${MDC(FUNCTION_NAME, name)} replaced a " +
        log"previously registered function.")
    case _ =>
  }
}

// Resolves the builder stored under the normalized `name` and applies it to `children`.
// Throws QueryCompilationErrors.unresolvedRoutineError when no entry exists for that name.
override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): T = {
val func = synchronized {
functionBuilders.get(normalizeFuncName(name)).map(_._2).getOrElse {
throw QueryCompilationErrors.unresolvedRoutineError(name, Seq("system.builtin"))
}
val func = Option(functionBuilders.get(normalizeFuncName(name))).map(_._2).getOrElse {
throw QueryCompilationErrors.unresolvedRoutineError(name, Seq("system.builtin"))
}
func(children)
}

// Returns the identifiers (map keys) of every entry currently held in `functionBuilders`.
override def listFunction(): Seq[FunctionIdentifier] = synchronized {
functionBuilders.iterator.map(_._1).toList
}
override def listFunction(): Seq[FunctionIdentifier] =
functionBuilders.keys().asScala.toSeq

// Returns the ExpressionInfo stored under the normalized `name`, or None when there is no
// entry for it.
override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = synchronized {
functionBuilders.get(normalizeFuncName(name)).map(_._1)
}
override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] =
Option(functionBuilders.get(normalizeFuncName(name))).map(_._1)

// Returns the FunctionBuilder stored under the normalized `name`, or None when there is no
// entry for it.
override def lookupFunctionBuilder(
name: FunctionIdentifier): Option[FunctionBuilder] = synchronized {
functionBuilders.get(normalizeFuncName(name)).map(_._2)
}
override def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder] =
Option(functionBuilders.get(normalizeFuncName(name))).map(_._2)

// Removes the entry stored under the normalized `name`.
// Returns true if an entry was present and removed, false otherwise.
override def dropFunction(name: FunctionIdentifier): Boolean = synchronized {
functionBuilders.remove(normalizeFuncName(name)).isDefined
}
override def dropFunction(name: FunctionIdentifier): Boolean =
Option(functionBuilders.remove(normalizeFuncName(name))).isDefined

// Removes every entry from `functionBuilders`.
override def clear(): Unit = synchronized {
functionBuilders.clear()
}
override def clear(): Unit = functionBuilders.clear()
}

/**
Expand Down Expand Up @@ -308,7 +298,11 @@ class SimpleFunctionRegistry

override def clone(): SimpleFunctionRegistry = synchronized {
val registry = new SimpleFunctionRegistry
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
val iterator = functionBuilders.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val name = entry.getKey
val (info, builder) = entry.getValue
registry.internalRegisterFunction(name, info, builder)
}
registry
Expand Down Expand Up @@ -1032,7 +1026,11 @@ class SimpleTableFunctionRegistry extends SimpleFunctionRegistryBase[LogicalPlan

override def clone(): SimpleTableFunctionRegistry = synchronized {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a problem here. We don't synchronize the write path, so how can we safely clone the ConcurrentHashMap?

Copy link
Contributor

@tedyu tedyu Jun 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can introduce a ReentrantReadWriteLock guarding the ConcurrentHashMap.
clone and clear would take the write lock.
The other methods would take the read lock.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may work, but it makes the code more complicated. We should only do it if it makes a measurable difference to real-world workloads.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, it's safe here,
because the entrySet() of ConcurrentHashMap is thread-safe. We don't need synchronized.

val registry = new SimpleTableFunctionRegistry
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
val iterator = functionBuilders.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val name = entry.getKey
val (info, builder) = entry.getValue
registry.internalRegisterFunction(name, info, builder)
}
registry
Expand Down