Support map parameters
MaxGekk committed Aug 31, 2023
1 parent 81e7189 commit 91bc0ce
Showing 2 changed files with 42 additions and 3 deletions.
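
This change relaxes the argument check in BindParameters: in addition to plain Literals, parameterized spark.sql queries now accept the constructor expressions CreateArray and MapFromArrays as arguments, provided every leaf of the argument expression tree is itself a literal. A new test in ParametersSuite binds maps to both named (:mapParam) and positional (?) parameters.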
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala

@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Literal, SubqueryExpression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, LeafExpression, Literal, MapFromArrays, SubqueryExpression, Unevaluable}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
@@ -96,7 +96,11 @@ case class PosParameterizedQuery(child: LogicalPlan, args: Array[Expression])
  */
 object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
   private def checkArgs(args: Iterable[(String, Expression)]): Unit = {
-    args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
+    def isNotAllowed(expr: Expression): Boolean = expr.exists {
+      case _: Literal | _: CreateArray | _: MapFromArrays => false
+      case _ => true
+    }
+    args.find(arg => isNotAllowed(arg._2)).foreach { case (name, expr) =>
       expr.failAnalysis(
         errorClass = "INVALID_SQL_ARG",
         messageParameters = Map("name" -> name))
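
The interesting part of the fix is that the check is now recursive: `exists` walks the whole argument tree and flags the argument as soon as it finds any node that is not a `Literal`, `CreateArray`, or `MapFromArrays`. Constructors may therefore nest, but every leaf must still be a literal. A minimal sketch of the semantics, assuming a Spark 3.5-era classpath (`Rand` here just stands in for any disallowed expression):

import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, Literal, MapFromArrays, Rand}

// Same predicate as in the patch: true iff the tree contains a node
// outside the allow-list (Literal, CreateArray, MapFromArrays).
def isNotAllowed(expr: Expression): Boolean = expr.exists {
  case _: Literal | _: CreateArray | _: MapFromArrays => false
  case _ => true
}

// Literal-only constructors pass, even when nested.
isNotAllowed(CreateArray(Seq(Literal(1), Literal(2))))            // false
isNotAllowed(MapFromArrays(
  CreateArray(Seq(Literal("a"))),
  CreateArray(Seq(Literal(0)))))                                  // false
// A single non-literal leaf rejects the whole argument.
isNotAllowed(CreateArray(Seq(Literal(1), Rand(Literal(42L)))))    // true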
sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala

@@ -19,8 +19,9 @@ package org.apache.spark.sql
 
 import java.time.{Instant, LocalDate, LocalDateTime, ZoneId}
 
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{array, lit, map_from_arrays}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession

@@ -529,4 +530,38 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       spark.sql("SELECT ?[?][?]", Array(Array(Array(1f, 2f), Array.empty[Float], Array(3f)), 0, 1)),
       Row(2f))
   }
+
+  test("SPARK-XXXXX: maps as parameters") {
+    def fromArr(keys: Array[_], values: Array[_]): Column = {
+      map_from_arrays(Column(Literal(keys)), Column(Literal(values)))
+    }
+    checkAnswer(
+      spark.sql("SELECT map_contains_key(:mapParam, 0)",
+        Map("mapParam" -> fromArr(Array.empty[Int], Array.empty[String]))),
+      Row(false))
+    checkAnswer(
+      spark.sql("SELECT map_contains_key(?, 'a')",
+        Array(fromArr(Array.empty[String], Array.empty[Double]))),
+      Row(false))
+    checkAnswer(
+      spark.sql("SELECT element_at(:mapParam, 'a')",
+        Map("mapParam" -> fromArr(Array("a"), Array(0)))),
+      Row(0))
+    checkAnswer(
+      spark.sql("SELECT element_at(?, 'a')", Array(fromArr(Array("a"), Array(0)))),
+      Row(0))
+    checkAnswer(
+      spark.sql("SELECT :m[10]", Map("m" -> fromArr(Array(10, 20, 30), Array(0, 1, 2)))),
+      Row(0))
+    checkAnswer(
+      spark.sql("SELECT ?[?]", Array(fromArr(Array(1f, 2f, 3f), Array(1, 2, 3)), 2f)),
+      Row(2))
+    checkAnswer(
+      spark.sql("SELECT :m['a'][1]",
+        Map("m" ->
+          map_from_arrays(
+            Column(Literal(Array("a"))),
+            array(map_from_arrays(Column(Literal(Array(1))), Column(Literal(Array(2)))))))),
+      Row(2))
+  }
 }
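
In user code, the pattern these tests exercise is straightforward: build the map argument with map_from_arrays (or, as the suite does, from catalyst Literals wrapped in Column) and pass it through the named or positional spark.sql overloads. A hypothetical spark-shell sketch, assuming an active SparkSession named spark:

import org.apache.spark.sql.functions.{lit, map_from_arrays}

// Build a map<string, int> column from two array literals.
val m = map_from_arrays(lit(Array("a", "b")), lit(Array(1, 2)))

// Named parameter: the analyzer now accepts MapFromArrays as an argument.
spark.sql("SELECT element_at(:m, 'b')", Map("m" -> m)).show()      // 2

// Positional parameters work the same way.
spark.sql("SELECT map_contains_key(?, 'a')", Array(m)).show()      // true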
