Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v1] Refactor representation of set ops #1538

Merged
merged 9 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 166 additions & 170 deletions partiql-ast/api/partiql-ast.api

Large diffs are not rendered by default.

91 changes: 49 additions & 42 deletions partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.partiql.ast.Let
import org.partiql.ast.OnConflict
import org.partiql.ast.OrderBy
import org.partiql.ast.Path
import org.partiql.ast.QueryBody
import org.partiql.ast.Returning
import org.partiql.ast.Select
import org.partiql.ast.SetOp
Expand Down Expand Up @@ -621,26 +622,6 @@ private class AstTranslator(val metas: Map<String, MetaContainer>) : AstBaseVisi
call("date_diff", operands, metas)
}

override fun visitExprBagOp(node: Expr.BagOp, ctx: Ctx) = translate(node) { metas ->
val lhs = visitExpr(node.lhs, ctx)
val rhs = visitExpr(node.rhs, ctx)
val op = when (node.outer) {
true -> when (node.type.type) {
SetOp.Type.UNION -> outerUnion()
SetOp.Type.INTERSECT -> outerIntersect()
SetOp.Type.EXCEPT -> outerExcept()
}
else -> when (node.type.type) {
SetOp.Type.UNION -> union()
SetOp.Type.INTERSECT -> intersect()
SetOp.Type.EXCEPT -> except()
}
}
val setq = node.type.setq?.toLegacySetQuantifier() ?: distinct()
val operands = listOf(lhs, rhs)
bagOp(op, setq, operands, metas)
}

override fun visitExprMatch(node: Expr.Match, ctx: Ctx) = translate(node) { metas ->
val expr = visitExpr(node.expr, ctx)
val match = visitGraphMatch(node.pattern, ctx)
Expand Down Expand Up @@ -673,36 +654,62 @@ private class AstTranslator(val metas: Map<String, MetaContainer>) : AstBaseVisi
/**
* SELECT-FROM-WHERE
*/

override fun visitExprSFW(node: Expr.SFW, ctx: Ctx) = translate(node) { metas ->
var setq = when (val s = node.select) {
is Select.Pivot -> null
is Select.Project -> s.setq?.toLegacySetQuantifier()
is Select.Star -> s.setq?.toLegacySetQuantifier()
is Select.Value -> s.setq?.toLegacySetQuantifier()
}
// Legacy AST removes (setq (all))
if (setq != null && setq is PartiqlAst.SetQuantifier.All) {
setq = null
}
val project = visitSelect(node.select, ctx)
val from = visitFrom(node.from, ctx)
val exclude = node.exclude?.let { visitExclude(it, ctx) }
val fromLet = node.let?.let { visitLet(it, ctx) }
val where = node.where?.let { visitExpr(it, ctx) }
val groupBy = node.groupBy?.let { visitGroupBy(it, ctx) }
val having = node.having?.let { visitExpr(it, ctx) }
override fun visitExprQuerySet(node: Expr.QuerySet, ctx: Ctx) = translate(node) { metas ->
val orderBy = node.orderBy?.let { visitOrderBy(it, ctx) }
val limit = node.limit?.let { visitExpr(it, ctx) }
val offset = node.offset?.let { visitExpr(it, ctx) }
select(setq, project, exclude, from, fromLet, where, groupBy, having, orderBy, limit, offset, metas)
when (val body = node.body) {
is QueryBody.SFW -> {
var setq = when (val s = body.select) {
is Select.Pivot -> null
is Select.Project -> s.setq?.toLegacySetQuantifier()
is Select.Star -> s.setq?.toLegacySetQuantifier()
is Select.Value -> s.setq?.toLegacySetQuantifier()
}
// Legacy AST removes (setq (all))
if (setq != null && setq is PartiqlAst.SetQuantifier.All) {
setq = null
}
val project = visitSelect(body.select, ctx)
val from = visitFrom(body.from, ctx)
val exclude = body.exclude?.let { visitExclude(it, ctx) }
val fromLet = body.let?.let { visitLet(it, ctx) }
val where = body.where?.let { visitExpr(it, ctx) }
val groupBy = body.groupBy?.let { visitGroupBy(it, ctx) }
val having = body.having?.let { visitExpr(it, ctx) }
select(setq, project, exclude, from, fromLet, where, groupBy, having, orderBy, limit, offset, metas)
}
is QueryBody.SetOp -> {
val lhs = visitExpr(body.lhs, ctx)
val rhs = visitExpr(body.rhs, ctx)
val outer = body.isOuter
val op = when (body.type.type) {
SetOp.Type.UNION -> if (outer) {
outerUnion()
} else {
union()
}
SetOp.Type.INTERSECT -> if (outer) {
outerIntersect()
} else {
intersect()
}
SetOp.Type.EXCEPT -> if (outer) {
outerExcept()
} else {
except()
}
}
val setq = body.type.setq?.toLegacySetQuantifier() ?: distinct()
val operands = listOf(lhs, rhs)
bagOp(op, setq, operands, metas)
}
}
}

/**
* UNSUPPORTED in legacy AST
*/
override fun visitExprSFWSetOp(node: Expr.SFW.SetOp, ctx: Ctx) = defaultVisit(node, ctx)

override fun visitSelect(node: Select, ctx: Ctx) = super.visitSelect(node, ctx) as PartiqlAst.Projection

override fun visitSelectStar(node: Select.Star, ctx: Ctx) = translate(node) { metas ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package org.partiql.ast.normalize
import org.partiql.ast.AstNode
import org.partiql.ast.Expr
import org.partiql.ast.From
import org.partiql.ast.QueryBody
import org.partiql.ast.Statement
import org.partiql.ast.fromJoin
import org.partiql.ast.helpers.toBinder
Expand All @@ -32,7 +33,7 @@ internal object NormalizeFromSource : AstPass {
private object Visitor : AstRewriter<Int>() {

// Each SFW starts the ctx count again.
override fun visitExprSFW(node: Expr.SFW, ctx: Int): AstNode = super.visitExprSFW(node, 0)
override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: Int): AstNode = super.visitQueryBodySFW(node, 0)

override fun visitStatementDMLBatchLegacy(node: Statement.DML.BatchLegacy, ctx: Int): AstNode =
super.visitStatementDMLBatchLegacy(node, 0)
Expand Down
78 changes: 36 additions & 42 deletions partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.partiql.ast.Identifier
import org.partiql.ast.Let
import org.partiql.ast.OrderBy
import org.partiql.ast.Path
import org.partiql.ast.QueryBody
import org.partiql.ast.Select
import org.partiql.ast.SetOp
import org.partiql.ast.SetQuantifier
Expand Down Expand Up @@ -57,10 +58,10 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
* @param head
*/
public open fun visitExprWrapped(node: Expr, head: SqlBlock): SqlBlock = when (node) {
is Expr.SFW -> {
is Expr.QuerySet -> {
var h = head
h = h concat "("
h = visitExprSFW(node, h)
h = visitExpr(node, h)
h = h concat ")"
h
}
Expand Down Expand Up @@ -552,33 +553,22 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
return h
}

override fun visitExprBagOp(node: Expr.BagOp, head: SqlBlock): SqlBlock {
// [OUTER] [UNION|INTERSECT|EXCEPT] [ALL|DISTINCT]
val op = mutableListOf<String>()
when (node.outer) {
true -> op.add("OUTER")
else -> {}
}
when (node.type.type) {
SetOp.Type.UNION -> op.add("UNION")
SetOp.Type.INTERSECT -> op.add("INTERSECT")
SetOp.Type.EXCEPT -> op.add("EXCEPT")
}
when (node.type.setq) {
SetQuantifier.ALL -> op.add("ALL")
SetQuantifier.DISTINCT -> op.add("DISTINCT")
null -> {}
}
override fun visitExprQuerySet(node: Expr.QuerySet, head: SqlBlock): SqlBlock {
var h = head
h = visitExprWrapped(node.lhs, h)
h = h concat r(" ${op.joinToString(" ")} ")
h = visitExprWrapped(node.rhs, h)
// visit body (SFW or other SQL set op)
h = visit(node.body, h)
// ORDER BY
h = if (node.orderBy != null) visitOrderBy(node.orderBy, h concat r(" ")) else h
// LIMIT
h = if (node.limit != null) visitExprWrapped(node.limit, h concat r(" LIMIT ")) else h
// OFFSET
h = if (node.offset != null) visitExprWrapped(node.offset, h concat r(" OFFSET ")) else h
return h
}

// SELECT-FROM-WHERE

override fun visitExprSFW(node: Expr.SFW, head: SqlBlock): SqlBlock {
override fun visitQueryBodySFW(node: QueryBody.SFW, head: SqlBlock): SqlBlock {
var h = head
// SELECT
h = visit(node.select, h)
Expand All @@ -594,14 +584,29 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
h = if (node.groupBy != null) visitGroupBy(node.groupBy, h concat r(" ")) else h
// HAVING
h = if (node.having != null) visitExprWrapped(node.having, h concat r(" HAVING ")) else h
// SET OP
h = if (node.setOp != null) visitExprSFWSetOp(node.setOp, h concat r(" ")) else h
// ORDER BY
h = if (node.orderBy != null) visitOrderBy(node.orderBy, h concat r(" ")) else h
// LIMIT
h = if (node.limit != null) visitExprWrapped(node.limit, h concat r(" LIMIT ")) else h
// OFFSET
h = if (node.offset != null) visitExprWrapped(node.offset, h concat r(" OFFSET ")) else h
return h
}

override fun visitQueryBodySetOp(node: QueryBody.SetOp, head: SqlBlock): SqlBlock {
val op = mutableListOf<String>()
when (node.isOuter) {
true -> op.add("OUTER")
else -> {}
}
when (node.type.type) {
SetOp.Type.UNION -> op.add("UNION")
SetOp.Type.INTERSECT -> op.add("INTERSECT")
SetOp.Type.EXCEPT -> op.add("EXCEPT")
}
when (node.type.setq) {
SetQuantifier.ALL -> op.add("ALL")
SetQuantifier.DISTINCT -> op.add("DISTINCT")
null -> {}
}
var h = head
h = visitExprWrapped(node.lhs, h)
h = h concat r(" ${op.joinToString(" ")} ")
h = visitExprWrapped(node.rhs, h)
return h
}

Expand Down Expand Up @@ -736,17 +741,6 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
return head concat r(op)
}

override fun visitExprSFWSetOp(node: Expr.SFW.SetOp, head: SqlBlock): SqlBlock {
var h = head
h = visitSetOp(node.type, h)
h = h concat r(" ")
h = h concat r("(")
val subquery = visitExprSFW(node.operand, SqlBlock.Nil)
h = h concat SqlBlock.Nest(subquery)
h = h concat r(")")
return h
}

// ORDER BY

override fun visitOrderBy(node: OrderBy, head: SqlBlock): SqlBlock = head concat list("ORDER BY ", "") { node.sorts }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.partiql.ast.Identifier
import org.partiql.ast.Let
import org.partiql.ast.OrderBy
import org.partiql.ast.Path
import org.partiql.ast.QueryBody
import org.partiql.ast.Select
import org.partiql.ast.SetOp
import org.partiql.ast.SetQuantifier
Expand Down Expand Up @@ -75,10 +76,10 @@ internal abstract class InternalSqlDialect : AstBaseVisitor<InternalSqlBlock, In
* @param tail
*/
open fun visitExprWrapped(node: Expr, tail: InternalSqlBlock): InternalSqlBlock = when (node) {
is Expr.SFW -> {
is Expr.QuerySet -> {
var t = tail
t = t concat "("
t = visitExprSFW(node, t)
t = visit(node, t)
t = t concat ")"
t
}
Expand Down Expand Up @@ -577,10 +578,43 @@ internal abstract class InternalSqlDialect : AstBaseVisitor<InternalSqlBlock, In
return t
}

override fun visitExprBagOp(node: Expr.BagOp, tail: InternalSqlBlock): InternalSqlBlock {
// [OUTER] [UNION|INTERSECT|EXCEPT] [ALL|DISTINCT]
override fun visitExprQuerySet(node: Expr.QuerySet, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
// visit body (SFW or other SQL set op)
t = visit(node.body, t)
// ORDER BY
t = if (node.orderBy != null) visitOrderBy(node.orderBy, t concat " ") else t
// LIMIT
t = if (node.limit != null) visitExprWrapped(node.limit, t concat " LIMIT ") else t
// OFFSET
t = if (node.offset != null) visitExprWrapped(node.offset, t concat " OFFSET ") else t
return t
}

// SELECT-FROM-WHERE

override fun visitQueryBodySFW(node: QueryBody.SFW, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
// SELECT
t = visit(node.select, t)
// EXCLUDE
t = node.exclude?.let { visit(it, t) } ?: t
// FROM
t = visit(node.from, t concat " FROM ")
// LET
t = if (node.let != null) visitLet(node.let, t concat " ") else t
// WHERE
t = if (node.where != null) visitExprWrapped(node.where, t concat " WHERE ") else t
// GROUP BY
t = if (node.groupBy != null) visitGroupBy(node.groupBy, t concat " ") else t
// HAVING
t = if (node.having != null) visitExprWrapped(node.having, t concat " HAVING ") else t
return t
}

override fun visitQueryBodySetOp(node: QueryBody.SetOp, tail: InternalSqlBlock): InternalSqlBlock {
val op = mutableListOf<String>()
when (node.outer) {
when (node.isOuter) {
true -> op.add("OUTER")
else -> {}
}
Expand All @@ -601,35 +635,6 @@ internal abstract class InternalSqlDialect : AstBaseVisitor<InternalSqlBlock, In
return t
}

// SELECT-FROM-WHERE

override fun visitExprSFW(node: Expr.SFW, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
// SELECT
t = visit(node.select, t)
// EXCLUDE
t = node.exclude?.let { visit(it, t) } ?: t
// FROM
t = visit(node.from, t concat " FROM ")
// LET
t = if (node.let != null) visitLet(node.let!!, t concat " ") else t
// WHERE
t = if (node.where != null) visitExprWrapped(node.where!!, t concat " WHERE ") else t
// GROUP BY
t = if (node.groupBy != null) visitGroupBy(node.groupBy!!, t concat " ") else t
// HAVING
t = if (node.having != null) visitExprWrapped(node.having!!, t concat " HAVING ") else t
// SET OP
t = if (node.setOp != null) visitExprSFWSetOp(node.setOp!!, t concat " ") else t
// ORDER BY
t = if (node.orderBy != null) visitOrderBy(node.orderBy!!, t concat " ") else t
// LIMIT
t = if (node.limit != null) visitExprWrapped(node.limit!!, t concat " LIMIT ") else t
// OFFSET
t = if (node.offset != null) visitExprWrapped(node.offset!!, t concat " OFFSET ") else t
return t
}

// SELECT

override fun visitSelectStar(node: Select.Star, tail: InternalSqlBlock): InternalSqlBlock {
Expand Down Expand Up @@ -761,17 +766,6 @@ internal abstract class InternalSqlDialect : AstBaseVisitor<InternalSqlBlock, In
return tail concat op
}

override fun visitExprSFWSetOp(node: Expr.SFW.SetOp, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
t = visitSetOp(node.type, t)
t = t concat InternalSqlBlock.Nest(
prefix = " (",
postfix = ")",
child = InternalSqlBlock.root().apply { visitExprSFW(node.operand, this) },
)
return t
}

// ORDER BY

override fun visitOrderBy(node: OrderBy, tail: InternalSqlBlock): InternalSqlBlock =
Expand Down
Loading
Loading