Skip to content

Commit

Permalink
ast, parser: fix parse join (#1129)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjhuang2016 authored Dec 21, 2020
1 parent 8fca2e9 commit 947cf4e
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 7 deletions.
58 changes: 57 additions & 1 deletion ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,63 @@ type Join struct {
// NaturalJoin represents join is natural join.
NaturalJoin bool
// StraightJoin represents a straight join.
StraightJoin bool
StraightJoin bool
ExplicitParens bool
}

// NewCrossJoin builds a cross join (a join with no `on` or `using` clause) of left and right.
//
// When right is not a join tree (it is not a *Join, or it is a *Join whose Right is nil),
// the result is simply a new Join{Left: left, Right: right, Tp: CrossJoin}.
//
// When right is itself a join tree, wrapping it directly would give the wrong precedence.
// Take `t1 join t2 join t3 on t2.a = t3.a`, where left is t1 and right is:
//
//	    JOIN ON t2.a = t3.a
//	    /  \
//	  t2    t3
//
// Building the tree directly would give:
//
//	    JOIN
//	    /  \
//	  t1    JOIN ON t2.a = t3.a
//	        /  \
//	      t2    t3
//
// which means t1 join (t2 join t3 on t2.a = t3.a) rather than the intended
// (t1 join t2) join t3 on t2.a = t3.a. Instead, the left-most leaf of right (t2) is located
// and replaced by a cross join of left with that leaf:
//
//	        JOIN ON t2.a = t3.a
//	        /  \
//	    JOIN    t3
//	    /  \
//	  t1    t2
//
// While walking down the left spine of right, every join that is a right join and carries
// explicit parentheses is rewritten into a left join by swapping its children. For example,
// `t1 join (t2 right join t3 on ...)` becomes `(t1 join t3) left join t2 on ...`; without the
// rewrite it would turn into `(t1 join t2) right join t3 on ...`, which has different semantics.
//
// Note that the join tree rooted at right is modified in place, and its root is returned.
func NewCrossJoin(left, right ResultSetNode) (n *Join) {
	rj, isJoin := right.(*Join)
	if !isJoin || rj.Right == nil {
		return &Join{Left: left, Right: right, Tp: CrossJoin}
	}

	// parent ends up as the deepest join on the left spine of right, i.e. the
	// father of the left-most leaf.
	parent := rj
	for {
		if parent.ExplicitParens && parent.Tp == RightJoin {
			// Turn the parenthesized right join into the equivalent left join.
			parent.Left, parent.Right = parent.Right, parent.Left
			parent.Tp = LeftJoin
		}
		child, childIsJoin := parent.Left.(*Join)
		if !childIsJoin || child.Right == nil {
			// parent.Left is a leaf (or a single-table wrapper): stop descending.
			break
		}
		parent = child
	}

	// Splice left in: the left-most leaf becomes the right side of a new cross join.
	parent.Left = &Join{Left: left, Right: parent.Left, Tp: CrossJoin}
	return rj
}

// Restore implements Node interface.
Expand Down
10 changes: 8 additions & 2 deletions ast/dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,21 @@ func (tc *testDMLSuite) TestJoinRestore(c *C) {
{"t1 inner join t2 using (b)", "`t1` JOIN `t2` USING (`b`)"},
{"t1 join t2 using (b,c) left join t3 on t1.a>t3.a", "(`t1` JOIN `t2` USING (`b`,`c`)) LEFT JOIN `t3` ON `t1`.`a`>`t3`.`a`"},
{"t1 natural join t2 right outer join t3 using (b,c)", "(`t1` NATURAL JOIN `t2`) RIGHT JOIN `t3` USING (`b`,`c`)"},
{"(a al left join b bl on al.a1 > bl.b1) join (a ar right join b br on ar.a1 > br.b1)", "(`a` AS `al` LEFT JOIN `b` AS `bl` ON `al`.`a1`>`bl`.`b1`) JOIN (`a` AS `ar` RIGHT JOIN `b` AS `br` ON `ar`.`a1`>`br`.`b1`)"},
{"a al left join b bl on al.a1 > bl.b1, a ar right join b br on ar.a1 > br.b1", "(`a` AS `al` LEFT JOIN `b` AS `bl` ON `al`.`a1`>`bl`.`b1`) JOIN (`a` AS `ar` RIGHT JOIN `b` AS `br` ON `ar`.`a1`>`br`.`b1`)"},
{"t1, t2", "(`t1`) JOIN `t2`"},
{"t1, t2, t3", "((`t1`) JOIN `t2`) JOIN `t3`"},
}
testChangedCases := []NodeRestoreTestCase{
{"(a al left join b bl on al.a1 > bl.b1) join (a ar right join b br on ar.a1 > br.b1)", "((`a` AS `al` LEFT JOIN `b` AS `bl` ON `al`.`a1`>`bl`.`b1`) JOIN `b` AS `br`) LEFT JOIN `a` AS `ar` ON `ar`.`a1`>`br`.`b1`"},
{"a al left join b bl on al.a1 > bl.b1, a ar right join b br on ar.a1 > br.b1", "(`a` AS `al` LEFT JOIN `b` AS `bl` ON `al`.`a1`>`bl`.`b1`) JOIN (`a` AS `ar` RIGHT JOIN `b` AS `br` ON `ar`.`a1`>`br`.`b1`)"},
{"t1 join (t2 right join t3 on t2.a > t3.a join (t4 right join t5 on t4.a > t5.a))", "(((`t1` JOIN `t2`) RIGHT JOIN `t3` ON `t2`.`a`>`t3`.`a`) JOIN `t5`) LEFT JOIN `t4` ON `t4`.`a`>`t5`.`a`"},
{"t1 join t2 right join t3 on t2.a=t3.a", "(`t1` JOIN `t2`) RIGHT JOIN `t3` ON `t2`.`a`=`t3`.`a`"},
{"t1 join (t2 right join t3 on t2.a=t3.a)", "(`t1` JOIN `t3`) LEFT JOIN `t2` ON `t2`.`a`=`t3`.`a`"},
}
extractNodeFunc := func(node Node) Node {
return node.(*SelectStmt).From.TableRefs
}
RunNodeRestoreTest(c, testCases, "select * from %s", extractNodeFunc)
RunNodeRestoreTestWithFlagsStmtChange(c, testChangedCases, "select * from %s", extractNodeFunc)
}

func (ts *testDMLSuite) TestTableRefsClauseRestore(c *C) {
Expand Down
22 changes: 22 additions & 0 deletions ast/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ func (checker *nodeTextCleaner) Enter(in Node) (out Node, skipChildren bool) {
for _, opt := range node.Options {
opt.StrValue = strings.ToLower(opt.StrValue)
}
case *Join:
node.ExplicitParens = false
}
return in, false
}
Expand Down Expand Up @@ -189,3 +191,23 @@ func RunNodeRestoreTestWithFlags(c *C, nodeTestCases []NodeRestoreTestCase, temp
c.Assert(stmt2, DeepEquals, stmt, comment)
}
}

// RunNodeRestoreTestWithFlagsStmtChange is like RunNodeRestoreTestWithFlags, except that it only
// compares the restored SQL text against the expected SQL; it does not check that the ASTs are equal.
// Use it for cases where the restored statement is expected to differ structurally from the source.
//
// For every test case, the source and expected SQL fragments are substituted into template, the
// source statement is parsed, the node selected by extractNodeFunc is restored with
// DefaultRestoreFlags, and the restored text (substituted into template) must equal the expected SQL.
func RunNodeRestoreTestWithFlagsStmtChange(c *C, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node) {
	p := parser.New()
	p.EnableWindowFunc(true)
	for _, testCase := range nodeTestCases {
		srcSQL := fmt.Sprintf(template, testCase.sourceSQL)
		wantSQL := fmt.Sprintf(template, testCase.expectSQL)

		stmt, err := p.ParseOneStmt(srcSQL, "", "")
		c.Assert(err, IsNil, Commentf("source %#v", testCase))

		buf := &strings.Builder{}
		node := extractNodeFunc(stmt)
		err = node.Restore(NewRestoreCtx(DefaultRestoreFlags, buf))
		c.Assert(err, IsNil, Commentf("source %#v", testCase))

		gotSQL := fmt.Sprintf(template, buf.String())
		c.Assert(gotSQL, Equals, wantSQL, Commentf("source %#v; restore %v", testCase, gotSQL))
	}
}
6 changes: 4 additions & 2 deletions parser.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -1338,10 +1338,10 @@ import (
%precedence remove
%precedence lowerThenOrder
%precedence order
%left join straightJoin inner cross left right full natural

/* A dummy token to force the priority of TableRef production in a join. */
%left tableRefPriority
%left join straightJoin inner cross left right full natural
%precedence lowerThanOn
%precedence on using
%right assignmentEq
Expand Down Expand Up @@ -8028,6 +8028,8 @@ TableFactor:
}
| '(' TableRefs ')'
{
j := $2.(*ast.Join)
j.ExplicitParens = true
$$ = $2
}

Expand Down Expand Up @@ -8140,7 +8142,7 @@ JoinTable:
/* Use %prec to evaluate production TableRef before cross join */
TableRef CrossOpt TableRef %prec tableRefPriority
{
$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin}
$$ = ast.NewCrossJoin($1.(ast.ResultSetNode), $3.(ast.ResultSetNode))
}
| TableRef CrossOpt TableRef "ON" Expression
{
Expand Down
3 changes: 3 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ func (s *testParserSuite) TestDMLStmt(c *C) {
{"select * from t1 natural left outer join t2", true, "SELECT * FROM `t1` NATURAL LEFT JOIN `t2`"},
{"select * from t1 natural inner join t2", false, ""},
{"select * from t1 natural cross join t2", false, ""},
{"select * from t3 join t1 join t2 on t1.a=t2.a on t3.b=t2.b", true, "SELECT * FROM `t3` JOIN (`t1` JOIN `t2` ON `t1`.`a`=`t2`.`a`) ON `t3`.`b`=`t2`.`b`"},

// for straight_join
{"select * from t1 straight_join t2 on t1.id = t2.id", true, "SELECT * FROM `t1` STRAIGHT_JOIN `t2` ON `t1`.`id`=`t2`.`id`"},
Expand Down Expand Up @@ -5510,6 +5511,8 @@ func (checker *nodeTextCleaner) Enter(in ast.Node) (out ast.Node, skipChildren b
}
}
node.Specs = specs
case *ast.Join:
node.ExplicitParens = false
}
return in, false
}
Expand Down

0 comments on commit 947cf4e

Please sign in to comment.