From c6dd8cefc594b2920c396bba89c546d46c71879b Mon Sep 17 00:00:00 2001 From: PlanetScale Actions Bot <60239337+planetscale-actions-bot@users.noreply.github.com> Date: Thu, 29 Feb 2024 02:08:48 -0800 Subject: [PATCH] Bugfix: GROUP BY/HAVING alias resolution (#15344) (#4566) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Andres Taylor Co-authored-by: AndrĂ©s Taylor --- .../queries/aggregation/aggregation_test.go | 98 ++++++- go/vt/vtgate/executor_select_test.go | 12 +- .../planbuilder/testdata/aggr_cases.json | 146 ++++++---- .../planbuilder/testdata/filter_cases.json | 4 +- .../testdata/memory_sort_cases.json | 16 +- .../planbuilder/testdata/oltp_cases.json | 4 +- .../testdata/postprocess_cases.json | 32 +-- .../planbuilder/testdata/select_cases.json | 88 +++++- .../planbuilder/testdata/tpcc_cases.json | 6 +- .../planbuilder/testdata/tpch_cases.json | 10 +- .../planbuilder/testdata/union_cases.json | 12 +- go/vt/vtgate/semantics/analyzer.go | 25 +- go/vt/vtgate/semantics/analyzer_test.go | 60 ++-- go/vt/vtgate/semantics/binder.go | 170 ++++++++++- go/vt/vtgate/semantics/dependencies.go | 5 +- go/vt/vtgate/semantics/early_rewriter.go | 250 ++++++++++------ go/vt/vtgate/semantics/early_rewriter_test.go | 272 +++++++++++++++--- go/vt/vtgate/semantics/errors.go | 37 ++- go/vt/vtgate/semantics/scoper.go | 36 ++- go/vt/vtgate/semantics/vtable.go | 27 +- 20 files changed, 1004 insertions(+), 306 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index 3386e2210cc..fa5b3cbbe25 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -40,7 +40,21 @@ func start(t *testing.T) (utils.MySQLCompare, func()) { deleteAll := func() { _, _ = utils.ExecAllowError(t, mcmp.VtConn, "set workload = oltp") - tables := []string{"t9", "aggr_test", "t3", "t7_xxhash", "aggr_test_dates", "t7_xxhash_idx", "t1", "t2", "t10"} + tables := []string{ + "t3", + "t3_id7_idx", + "t9", + "aggr_test", + "aggr_test_dates", + "t7_xxhash", + "t7_xxhash_idx", + "t1", + "t2", + "t10", + "emp", + "dept", + "bet_logs", + } for _, table := range tables { _, _ = mcmp.ExecAndIgnore("delete from " + table) } @@ -668,3 +682,85 @@ func TestDistinctAggregation(t *testing.T) { }) } } + +func TestHavingQueries(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 18, "vtgate") + mcmp, closer := start(t) + defer closer() + + inserts := []string{ + `INSERT INTO emp (empno, ename, job, mgr, hiredate, sal, comm, deptno) VALUES + (1, 'John', 'Manager', NULL, '2022-01-01', 5000, 500, 1), + (2, 'Doe', 'Analyst', 1, '2023-01-01', 4500, NULL, 1), + (3, 'Jane', 'Clerk', 1, '2023-02-01', 3000, 200, 2), + (4, 'Mary', 'Analyst', 2, '2022-03-01', 4700, NULL, 1), + (5, 'Smith', 'Salesman', 3, '2023-01-15', 3200, 300, 3)`, + "INSERT INTO dept (deptno, dname, loc) VALUES (1, 'IT', 'New York'), (2, 'HR', 'London'), (3, 'Sales', 'San Francisco')", + "INSERT INTO t1 (t1_id, name, value, shardKey) VALUES (1, 'Name1', 'Value1', 100), (2, 'Name2', 'Value2', 100), (3, 'Name1', 'Value3', 200)", + "INSERT INTO aggr_test_dates (id, val1, val2) VALUES (1, '2023-01-01', '2023-01-02'), (2, '2023-02-01', '2023-02-02'), (3, '2023-03-01', '2023-03-02')", + "INSERT INTO t10 (k, a, b) VALUES (1, 10, 20), (2, 30, 40), (3, 50, 60)", + "INSERT INTO t3 (id5, id6, id7) VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300)", + "INSERT INTO t9 (id1, id2, id3) VALUES (1, 'A1', 'B1'), (2, 'A2', 'B2'), (3, 'A1', 'B3')", + "INSERT INTO aggr_test (id, val1, val2) VALUES (1, 'Test1', 100), (2, 'Test2', 200), (3, 'Test1', 300), (4, 'Test3', 400)", + "INSERT INTO t2 (id, shardKey) VALUES (1, 100), (2, 200), (3, 300)", + `INSERT INTO bet_logs (id, merchant_game_id, bet_amount, game_id) VALUES + (1, 1, 100.0, 10), + (2, 1, 200.0, 11), + (3, 2, 300.0, 10), + (4, 3, 400.0, 12)`, + } + + for _, insert := range inserts { + mcmp.Exec(insert) + } + + queries := []string{ + // The following queries are not allowed by MySQL but Vitess allows them + // SELECT ename FROM emp GROUP BY ename HAVING sal > 5000 + // SELECT val1, COUNT(val2) FROM aggr_test_dates GROUP BY val1 HAVING val2 > 5 + // SELECT k, a FROM t10 GROUP BY k HAVING b > 2 + // SELECT loc FROM dept GROUP BY loc HAVING COUNT(deptno) AND dname = 'Sales' + // SELECT AVG(val2) AS average_val2 FROM aggr_test HAVING val1 = 'Test' + + // these first queries are all failing in different ways. let's check that Vitess also fails + + "SELECT deptno, AVG(sal) AS average_salary HAVING average_salary > 5000 FROM emp", + "SELECT job, COUNT(empno) AS num_employees FROM emp HAVING num_employees > 2", + "SELECT dname, SUM(sal) FROM dept JOIN emp ON dept.deptno = emp.deptno HAVING AVG(sal) > 6000", + "SELECT COUNT(*) AS count FROM emp WHERE count > 5", + "SELECT `name`, AVG(`value`) FROM t1 GROUP BY `name` HAVING `name`", + "SELECT empno, MAX(sal) FROM emp HAVING COUNT(*) > 3", + "SELECT id, SUM(bet_amount) AS total_bets FROM bet_logs HAVING total_bets > 1000", + "SELECT merchant_game_id FROM bet_logs GROUP BY merchant_game_id HAVING SUM(bet_amount)", + "SELECT shardKey, COUNT(id) FROM t2 HAVING shardKey > 100", + "SELECT deptno FROM emp GROUP BY deptno HAVING MAX(hiredate) > '2020-01-01'", + + // These queries should not fail + "SELECT deptno, COUNT(*) AS num_employees FROM emp GROUP BY deptno HAVING num_employees > 5", + "SELECT ename, SUM(sal) FROM emp GROUP BY ename HAVING SUM(sal) > 10000", + "SELECT dname, AVG(sal) AS average_salary FROM emp JOIN dept ON emp.deptno = dept.deptno GROUP BY dname HAVING average_salary > 5000", + "SELECT dname, MAX(sal) AS max_salary FROM emp JOIN dept ON emp.deptno = dept.deptno GROUP BY dname HAVING max_salary < 10000", + "SELECT YEAR(hiredate) AS year, COUNT(*) FROM emp GROUP BY year HAVING COUNT(*) > 2", + "SELECT mgr, COUNT(empno) AS managed_employees FROM emp WHERE mgr IS NOT NULL GROUP BY mgr HAVING managed_employees >= 3", + "SELECT deptno, SUM(comm) AS total_comm FROM emp GROUP BY deptno HAVING total_comm > AVG(total_comm)", + "SELECT id2, COUNT(*) AS count FROM t9 GROUP BY id2 HAVING count > 1", + "SELECT val1, COUNT(*) FROM aggr_test GROUP BY val1 HAVING COUNT(*) > 1", + "SELECT DATE(val1) AS date, SUM(val2) FROM aggr_test_dates GROUP BY date HAVING SUM(val2) > 100", + "SELECT shardKey, AVG(`value`) FROM t1 WHERE `value` IS NOT NULL GROUP BY shardKey HAVING AVG(`value`) > 10", + "SELECT job, COUNT(*) AS job_count FROM emp GROUP BY job HAVING job_count > 3", + "SELECT b, AVG(a) AS avg_a FROM t10 GROUP BY b HAVING AVG(a) > 5", + "SELECT merchant_game_id, SUM(bet_amount) AS total_bets FROM bet_logs GROUP BY merchant_game_id HAVING total_bets > 1000", + "SELECT loc, COUNT(deptno) AS num_depts FROM dept GROUP BY loc HAVING num_depts > 1", + "SELECT `name`, COUNT(*) AS name_count FROM t1 GROUP BY `name` HAVING name_count > 2", + "SELECT COUNT(*) AS num_jobs FROM emp GROUP BY empno HAVING num_jobs > 1", + "SELECT id, COUNT(*) AS count FROM t2 GROUP BY id HAVING count > 1", + "SELECT val2, SUM(id) FROM aggr_test GROUP BY val2 HAVING SUM(id) > 10", + "SELECT game_id, COUNT(*) AS num_logs FROM bet_logs GROUP BY game_id HAVING num_logs > 5", + } + + for _, query := range queries { + mcmp.Run(query, func(mcmp *utils.MySQLCompare) { + mcmp.ExecAllowAndCompareError(query) + }) + } +} diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index a49e2592721..1f317d4c7fa 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -1926,7 +1926,7 @@ func TestSelectScatterOrderBy(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select col1, col2, weight_string(col2) from `user` order by col2 desc", + Sql: "select col1, col2, weight_string(col2) from `user` order by `user`.col2 desc", BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { @@ -1999,7 +1999,7 @@ func TestSelectScatterOrderByVarChar(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select col1, textcol, weight_string(textcol) from `user` order by textcol desc", + Sql: "select col1, textcol, weight_string(textcol) from `user` order by `user`.textcol desc", BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { @@ -2065,7 +2065,7 @@ func TestStreamSelectScatterOrderBy(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select id, col, weight_string(col) from `user` order by col desc", + Sql: "select id, col, weight_string(col) from `user` order by `user`.col desc", BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { @@ -2127,7 +2127,7 @@ func TestStreamSelectScatterOrderByVarChar(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select id, textcol, weight_string(textcol) from `user` order by textcol desc", + Sql: "select id, textcol, weight_string(textcol) from `user` order by `user`.textcol desc", BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { @@ -2323,7 +2323,7 @@ func TestSelectScatterLimit(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select col1, col2, weight_string(col2) from `user` order by col2 desc limit :__upper_limit", + Sql: "select col1, col2, weight_string(col2) from `user` order by `user`.col2 desc limit :__upper_limit", BindVariables: map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}, }} for _, conn := range conns { @@ -2395,7 +2395,7 @@ func TestStreamSelectScatterLimit(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select col1, col2, weight_string(col2) from `user` order by col2 desc limit :__upper_limit", + Sql: "select col1, col2, weight_string(col2) from `user` order by `user`.col2 desc limit :__upper_limit", BindVariables: map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}, }} for _, conn := range conns { diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 5644f3b8807..825c5b3b5a8 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -675,10 +675,15 @@ } }, { - "comment": "scatter aggregate group by aggregate function", + "comment": "scatter aggregate group by aggregate function - since we don't have authoratative columns for user, we can't be sure that the user isn't referring a column named b", "query": "select count(*) b from user group by b", "plan": "VT03005: cannot group on 'count(*)'" }, + { + "comment": "scatter aggregate group by aggregate function with column information", + "query": "select count(*) b from authoritative group by b", + "plan": "VT03005: cannot group on 'b'" + }, { "comment": "scatter aggregate multiple group by (columns)", "query": "select a, b, count(*) from user group by a, b", @@ -893,7 +898,7 @@ }, "FieldQuery": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` where 1 != 1 group by a, b, c, weight_string(a), weight_string(b), weight_string(c)", "OrderBy": "(0|5) ASC, (1|6) ASC, (2|7) ASC", - "Query": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` group by a, b, c, weight_string(a), weight_string(b), weight_string(c) order by a asc, b asc, c asc", + "Query": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` group by a, b, c, weight_string(a), weight_string(b), weight_string(c) order by `user`.a asc, `user`.b asc, `user`.c asc", "Table": "`user`" } ] @@ -925,7 +930,7 @@ }, "FieldQuery": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` where 1 != 1 group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c)", "OrderBy": "(3|5) ASC, (1|6) ASC, (0|7) ASC, (2|8) ASC", - "Query": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c) order by d asc, b asc, a asc, c asc", + "Query": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c) order by `user`.d asc, `user`.b asc, `user`.a asc, `user`.c asc", "Table": "`user`" } ] @@ -957,7 +962,7 @@ }, "FieldQuery": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` where 1 != 1 group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c)", "OrderBy": "(3|5) ASC, (1|6) ASC, (0|7) ASC, (2|8) ASC", - "Query": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c) order by d asc, b asc, a asc, c asc", + "Query": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c) order by `user`.d asc, `user`.b asc, `user`.a asc, `user`.c asc", "Table": "`user`" } ] @@ -989,7 +994,7 @@ }, "FieldQuery": "select a, b, c, count(*), weight_string(a), weight_string(c), weight_string(b) from `user` where 1 != 1 group by a, c, b, weight_string(a), weight_string(c), weight_string(b)", "OrderBy": "(0|4) DESC, (2|5) DESC, (1|6) ASC", - "Query": "select a, b, c, count(*), weight_string(a), weight_string(c), weight_string(b) from `user` group by a, c, b, weight_string(a), weight_string(c), weight_string(b) order by a desc, c desc, b asc", + "Query": "select a, b, c, count(*), weight_string(a), weight_string(c), weight_string(b) from `user` group by a, c, b, weight_string(a), weight_string(c), weight_string(b) order by a desc, c desc, `user`.b asc", "Table": "`user`" } ] @@ -1041,32 +1046,6 @@ ] } }, - { - "comment": "Group by with collate operator", - "query": "select user.col1 as a from user where user.id = 5 group by a collate utf8_general_ci", - "plan": { - "QueryType": "SELECT", - "Original": "select user.col1 as a from user where user.id = 5 group by a collate utf8_general_ci", - "Instructions": { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col1 as a from `user` where 1 != 1 group by `user`.col1 collate utf8_general_ci", - "Query": "select `user`.col1 as a from `user` where `user`.id = 5 group by `user`.col1 collate utf8_general_ci", - "Table": "`user`", - "Values": [ - "5" - ], - "Vindex": "user_index" - }, - "TablesUsed": [ - "user.user" - ] - } - }, { "comment": "routing rules for aggregates", "query": "select id, count(*) from route2 group by id", @@ -1103,7 +1082,7 @@ "Sharded": true }, "FieldQuery": "select col from ref where 1 != 1", - "Query": "select col from ref order by col asc", + "Query": "select col from ref order by ref.col asc", "Table": "ref" }, "TablesUsed": [ @@ -1584,10 +1563,10 @@ }, { "comment": "weight_string addition to group by", - "query": "select lower(textcol1) as v, count(*) from user group by v", + "query": "select lower(col1) as v, count(*) from authoritative group by v", "plan": { "QueryType": "SELECT", - "Original": "select lower(textcol1) as v, count(*) from user group by v", + "Original": "select lower(col1) as v, count(*) from authoritative group by v", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", @@ -1602,24 +1581,24 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select lower(textcol1) as v, count(*), weight_string(lower(textcol1)) from `user` where 1 != 1 group by lower(textcol1), weight_string(lower(textcol1))", + "FieldQuery": "select lower(col1) as v, count(*), weight_string(lower(col1)) from authoritative where 1 != 1 group by lower(col1), weight_string(lower(col1))", "OrderBy": "(0|2) ASC", - "Query": "select lower(textcol1) as v, count(*), weight_string(lower(textcol1)) from `user` group by lower(textcol1), weight_string(lower(textcol1)) order by lower(textcol1) asc", - "Table": "`user`" + "Query": "select lower(col1) as v, count(*), weight_string(lower(col1)) from authoritative group by lower(col1), weight_string(lower(col1)) order by lower(col1) asc", + "Table": "authoritative" } ] }, "TablesUsed": [ - "user.user" + "user.authoritative" ] } }, { "comment": "weight_string addition to group by when also there in order by", - "query": "select char_length(texcol1) as a, count(*) from user group by a order by a", + "query": "select char_length(col1) as a, count(*) from authoritative group by a order by a", "plan": { "QueryType": "SELECT", - "Original": "select char_length(texcol1) as a, count(*) from user group by a order by a", + "Original": "select char_length(col1) as a, count(*) from authoritative group by a order by a", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", @@ -1634,15 +1613,15 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select char_length(texcol1) as a, count(*), weight_string(char_length(texcol1)) from `user` where 1 != 1 group by char_length(texcol1), weight_string(char_length(texcol1))", + "FieldQuery": "select char_length(col1) as a, count(*), weight_string(char_length(col1)) from authoritative where 1 != 1 group by char_length(col1), weight_string(char_length(col1))", "OrderBy": "(0|2) ASC", - "Query": "select char_length(texcol1) as a, count(*), weight_string(char_length(texcol1)) from `user` group by char_length(texcol1), weight_string(char_length(texcol1)) order by char_length(texcol1) asc", - "Table": "`user`" + "Query": "select char_length(col1) as a, count(*), weight_string(char_length(col1)) from authoritative group by char_length(col1), weight_string(char_length(col1)) order by char_length(authoritative.col1) asc", + "Table": "authoritative" } ] }, "TablesUsed": [ - "user.user" + "user.authoritative" ] } }, @@ -1699,7 +1678,7 @@ }, "FieldQuery": "select col, id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(1|2) ASC", - "Query": "select col, id, weight_string(id) from `user` order by id asc", + "Query": "select col, id, weight_string(id) from `user` order by `user`.id asc", "ResultColumns": 2, "Table": "`user`" }, @@ -2009,19 +1988,20 @@ }, { "comment": "Less Equal filter on scatter with grouping", - "query": "select col, count(*) a from user group by col having a <= 10", + "query": "select col1, count(*) a from user group by col1 having a <= 10", "plan": { "QueryType": "SELECT", - "Original": "select col, count(*) a from user group by col having a <= 10", + "Original": "select col1, count(*) a from user group by col1 having a <= 10", "Instructions": { "OperatorType": "Filter", "Predicate": "count(*) <= 10", + "ResultColumns": 2, "Inputs": [ { "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "sum_count_star(1) AS a", - "GroupBy": "0", + "GroupBy": "(0|2)", "Inputs": [ { "OperatorType": "Route", @@ -2030,9 +2010,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col, count(*) as a from `user` where 1 != 1 group by col", - "OrderBy": "0 ASC", - "Query": "select col, count(*) as a from `user` group by col order by col asc", + "FieldQuery": "select col1, count(*) as a, weight_string(col1) from `user` where 1 != 1 group by col1, weight_string(col1)", + "OrderBy": "(0|2) ASC", + "Query": "select col1, count(*) as a, weight_string(col1) from `user` group by col1, weight_string(col1) order by col1 asc", "Table": "`user`" } ] @@ -2046,10 +2026,10 @@ }, { "comment": "We should be able to find grouping keys on ordered aggregates", - "query": "select count(*) as a, val1 from user group by val1 having a = 1.00", + "query": "select count(*) as a, col2 from user group by col2 having a = 1.00", "plan": { "QueryType": "SELECT", - "Original": "select count(*) as a, val1 from user group by val1 having a = 1.00", + "Original": "select count(*) as a, col2 from user group by col2 having a = 1.00", "Instructions": { "OperatorType": "Filter", "Predicate": "count(*) = 1.00", @@ -2068,9 +2048,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select count(*) as a, val1, weight_string(val1) from `user` where 1 != 1 group by val1, weight_string(val1)", + "FieldQuery": "select count(*) as a, col2, weight_string(col2) from `user` where 1 != 1 group by col2, weight_string(col2)", "OrderBy": "(1|2) ASC", - "Query": "select count(*) as a, val1, weight_string(val1) from `user` group by val1, weight_string(val1) order by val1 asc", + "Query": "select count(*) as a, col2, weight_string(col2) from `user` group by col2, weight_string(col2) order by col2 asc", "Table": "`user`" } ] @@ -2620,10 +2600,10 @@ }, { "comment": "group by column alias", - "query": "select ascii(val1) as a, count(*) from user group by a", + "query": "select ascii(col2) as a, count(*) from user group by a", "plan": { "QueryType": "SELECT", - "Original": "select ascii(val1) as a, count(*) from user group by a", + "Original": "select ascii(col2) as a, count(*) from user group by a", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", @@ -2638,9 +2618,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select ascii(val1) as a, count(*), weight_string(ascii(val1)) from `user` where 1 != 1 group by ascii(val1), weight_string(ascii(val1))", + "FieldQuery": "select ascii(col2) as a, count(*), weight_string(ascii(col2)) from `user` where 1 != 1 group by ascii(col2), weight_string(ascii(col2))", "OrderBy": "(0|2) ASC", - "Query": "select ascii(val1) as a, count(*), weight_string(ascii(val1)) from `user` group by ascii(val1), weight_string(ascii(val1)) order by ascii(val1) asc", + "Query": "select ascii(col2) as a, count(*), weight_string(ascii(col2)) from `user` group by ascii(col2), weight_string(ascii(col2)) order by ascii(col2) asc", "Table": "`user`" } ] @@ -2984,7 +2964,7 @@ "Original": "select foo, sum(foo) as fooSum, sum(bar) as barSum from user group by foo having fooSum+sum(bar) = 42", "Instructions": { "OperatorType": "Filter", - "Predicate": "sum(foo) + sum(bar) = 42", + "Predicate": "sum(`user`.foo) + sum(bar) = 42", "ResultColumns": 3, "Inputs": [ { @@ -3328,10 +3308,10 @@ }, { "comment": "group by and ',' joins", - "query": "select user.id from user, user_extra group by id", + "query": "select user.id from user, user_extra group by user.id", "plan": { "QueryType": "SELECT", - "Original": "select user.id from user, user_extra group by id", + "Original": "select user.id from user, user_extra group by user.id", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", @@ -3555,7 +3535,7 @@ }, "FieldQuery": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where 1 != 1) as x where 1 != 1", "OrderBy": "(1|3) ASC", - "Query": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where val2 < 4) as x order by val1 asc limit :__upper_limit", + "Query": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where val2 < 4) as x order by `user`.val1 asc limit :__upper_limit", "Table": "`user`" } ] @@ -6938,5 +6918,45 @@ "user.user" ] } + }, + { + "comment": "col is a column on user, but the HAVING is referring to an alias", + "query": "select sum(x) col from user where x > 0 having col = 2", + "plan": { + "QueryType": "SELECT", + "Original": "select sum(x) col from user where x > 0 having col = 2", + "Instructions": { + "OperatorType": "Filter", + "Predicate": "sum(`user`.x) = 2", + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS col", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(x) as col from `user` where 1 != 1", + "Query": "select sum(x) as col from `user` where x > 0", + "Table": "`user`" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "baz in the HAVING clause can't be accessed because of the GROUP BY", + "query": "select foo, count(bar) as x from user group by foo having baz > avg(baz) order by x", + "plan": "Unknown column 'baz' in 'having clause'" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.json b/go/vt/vtgate/planbuilder/testdata/filter_cases.json index cb64ba4e522..ace1fffcfa1 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.json @@ -4018,7 +4018,7 @@ "Sharded": true }, "FieldQuery": "select a + 2 as a from `user` where 1 != 1", - "Query": "select a + 2 as a from `user` where a + 2 = 42", + "Query": "select a + 2 as a from `user` where `user`.a + 2 = 42", "Table": "`user`" }, "TablesUsed": [ @@ -4041,7 +4041,7 @@ }, "FieldQuery": "select a + 2 as a, weight_string(a + 2) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select a + 2 as a, weight_string(a + 2) from `user` order by a + 2 asc", + "Query": "select a + 2 as a, weight_string(a + 2) from `user` order by `user`.a + 2 asc", "ResultColumns": 1, "Table": "`user`" }, diff --git a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json index 3ca7e1059e4..4a879997925 100644 --- a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json @@ -24,9 +24,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, count(*), weight_string(a), weight_string(b) from `user` where 1 != 1 group by a, weight_string(a)", + "FieldQuery": "select a, b, count(*), weight_string(a), weight_string(`user`.b) from `user` where 1 != 1 group by a, weight_string(a)", "OrderBy": "(0|3) ASC", - "Query": "select a, b, count(*), weight_string(a), weight_string(b) from `user` group by a, weight_string(a) order by a asc", + "Query": "select a, b, count(*), weight_string(a), weight_string(`user`.b) from `user` group by a, weight_string(a) order by a asc", "Table": "`user`" } ] @@ -102,9 +102,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, count(*) as k, weight_string(a), weight_string(b) from `user` where 1 != 1 group by a, weight_string(a)", + "FieldQuery": "select a, b, count(*) as k, weight_string(a), weight_string(`user`.b) from `user` where 1 != 1 group by a, weight_string(a)", "OrderBy": "(0|3) ASC", - "Query": "select a, b, count(*) as k, weight_string(a), weight_string(b) from `user` group by a, weight_string(a) order by a asc", + "Query": "select a, b, count(*) as k, weight_string(a), weight_string(`user`.b) from `user` group by a, weight_string(a) order by a asc", "Table": "`user`" } ] @@ -259,7 +259,7 @@ }, "FieldQuery": "select id, weight_string(id) from (select `user`.id, `user`.col from `user` where 1 != 1) as t where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select id, weight_string(id) from (select `user`.id, `user`.col from `user`) as t order by id asc", + "Query": "select id, weight_string(id) from (select `user`.id, `user`.col from `user`) as t order by t.id asc", "Table": "`user`" }, { @@ -552,9 +552,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, convert(a, binary), weight_string(convert(a, binary)) from `user` where 1 != 1", + "FieldQuery": "select a, convert(`user`.a, binary), weight_string(convert(`user`.a, binary)) from `user` where 1 != 1", "OrderBy": "(1|2) DESC", - "Query": "select a, convert(a, binary), weight_string(convert(a, binary)) from `user` order by convert(a, binary) desc", + "Query": "select a, convert(`user`.a, binary), weight_string(convert(`user`.a, binary)) from `user` order by convert(`user`.a, binary) desc", "ResultColumns": 1, "Table": "`user`" }, @@ -624,7 +624,7 @@ }, "FieldQuery": "select id, intcol from `user` where 1 != 1", "OrderBy": "1 ASC", - "Query": "select id, intcol from `user` order by intcol asc", + "Query": "select id, intcol from `user` order by `user`.intcol asc", "Table": "`user`" }, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/oltp_cases.json b/go/vt/vtgate/planbuilder/testdata/oltp_cases.json index 3af909415f9..45f1ac8c618 100644 --- a/go/vt/vtgate/planbuilder/testdata/oltp_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/oltp_cases.json @@ -91,7 +91,7 @@ }, "FieldQuery": "select c from sbtest1 where 1 != 1", "OrderBy": "0 ASC COLLATE latin1_swedish_ci", - "Query": "select c from sbtest1 where id between 50 and 235 order by c asc", + "Query": "select c from sbtest1 where id between 50 and 235 order by sbtest1.c asc", "Table": "sbtest1" }, "TablesUsed": [ @@ -119,7 +119,7 @@ }, "FieldQuery": "select c from sbtest30 where 1 != 1 group by c", "OrderBy": "0 ASC COLLATE latin1_swedish_ci", - "Query": "select c from sbtest30 where id between 1 and 10 group by c order by c asc", + "Query": "select c from sbtest30 where id between 1 and 10 group by c order by sbtest30.c asc", "Table": "sbtest30" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index cef721e1c6a..78144f5c33c 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -145,7 +145,7 @@ "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where :__sq_has_values and id in ::__vals", + "Query": "select id from `user` where :__sq_has_values and `user`.id in ::__vals", "Table": "`user`", "Values": [ "::__sq1" @@ -226,7 +226,7 @@ }, "FieldQuery": "select col from `user` where 1 != 1", "OrderBy": "0 ASC", - "Query": "select col from `user` order by col asc", + "Query": "select col from `user` order by `user`.col asc", "Table": "`user`" }, "TablesUsed": [ @@ -249,7 +249,7 @@ }, "FieldQuery": "select user_id, col1, col2, weight_string(user_id) from authoritative where 1 != 1", "OrderBy": "(0|3) ASC", - "Query": "select user_id, col1, col2, weight_string(user_id) from authoritative order by user_id asc", + "Query": "select user_id, col1, col2, weight_string(user_id) from authoritative order by authoritative.user_id asc", "ResultColumns": 3, "Table": "authoritative" }, @@ -273,7 +273,7 @@ }, "FieldQuery": "select user_id, col1, col2 from authoritative where 1 != 1", "OrderBy": "1 ASC COLLATE latin1_swedish_ci", - "Query": "select user_id, col1, col2 from authoritative order by col1 asc", + "Query": "select user_id, col1, col2 from authoritative order by authoritative.col1 asc", "Table": "authoritative" }, "TablesUsed": [ @@ -296,7 +296,7 @@ }, "FieldQuery": "select a, textcol1, b, weight_string(a), weight_string(b) from `user` where 1 != 1", "OrderBy": "(0|3) ASC, 1 ASC COLLATE latin1_swedish_ci, (2|4) ASC", - "Query": "select a, textcol1, b, weight_string(a), weight_string(b) from `user` order by a asc, textcol1 asc, b asc", + "Query": "select a, textcol1, b, weight_string(a), weight_string(b) from `user` order by `user`.a asc, `user`.textcol1 asc, `user`.b asc", "ResultColumns": 3, "Table": "`user`" }, @@ -320,7 +320,7 @@ }, "FieldQuery": "select a, `user`.textcol1, b, weight_string(a), weight_string(b) from `user` where 1 != 1", "OrderBy": "(0|3) ASC, 1 ASC COLLATE latin1_swedish_ci, (2|4) ASC", - "Query": "select a, `user`.textcol1, b, weight_string(a), weight_string(b) from `user` order by a asc, `user`.textcol1 asc, b asc", + "Query": "select a, `user`.textcol1, b, weight_string(a), weight_string(b) from `user` order by `user`.a asc, `user`.textcol1 asc, `user`.b asc", "ResultColumns": 3, "Table": "`user`" }, @@ -344,7 +344,7 @@ }, "FieldQuery": "select a, textcol1, b, textcol2, weight_string(a), weight_string(b), weight_string(textcol2) from `user` where 1 != 1", "OrderBy": "(0|4) ASC, 1 ASC COLLATE latin1_swedish_ci, (2|5) ASC, (3|6) ASC COLLATE ", - "Query": "select a, textcol1, b, textcol2, weight_string(a), weight_string(b), weight_string(textcol2) from `user` order by a asc, textcol1 asc, b asc, textcol2 asc", + "Query": "select a, textcol1, b, textcol2, weight_string(a), weight_string(b), weight_string(textcol2) from `user` order by `user`.a asc, `user`.textcol1 asc, `user`.b asc, `user`.textcol2 asc", "ResultColumns": 4, "Table": "`user`" }, @@ -440,7 +440,7 @@ }, "FieldQuery": "select col from `user` where 1 != 1", "OrderBy": "0 ASC", - "Query": "select col from `user` where :__sq_has_values and col in ::__sq1 order by col asc", + "Query": "select col from `user` where :__sq_has_values and col in ::__sq1 order by `user`.col asc", "Table": "`user`" } ] @@ -1079,7 +1079,7 @@ "Sharded": true }, "FieldQuery": "select col from `user` as route1 where 1 != 1", - "Query": "select col from `user` as route1 where id = 1 order by col asc", + "Query": "select col from `user` as route1 where id = 1 order by route1.col asc", "Table": "`user`", "Values": [ "1" @@ -1365,7 +1365,7 @@ }, "FieldQuery": "select id as foo, weight_string(id) from music where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select id as foo, weight_string(id) from music order by id asc", + "Query": "select id as foo, weight_string(id) from music order by music.id asc", "ResultColumns": 1, "Table": "music" }, @@ -1389,7 +1389,7 @@ }, "FieldQuery": "select id as foo, id2 as id, weight_string(id2) from music where 1 != 1", "OrderBy": "(1|2) ASC", - "Query": "select id as foo, id2 as id, weight_string(id2) from music order by id2 asc", + "Query": "select id as foo, id2 as id, weight_string(id2) from music order by music.id2 asc", "ResultColumns": 2, "Table": "music" }, @@ -1419,7 +1419,7 @@ }, "FieldQuery": "select `name`, weight_string(`name`) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select `name`, weight_string(`name`) from `user` order by `name` asc", + "Query": "select `name`, weight_string(`name`) from `user` order by `user`.`name` asc", "Table": "`user`" }, { @@ -1606,7 +1606,7 @@ }, "FieldQuery": "select `name`, weight_string(`name`) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select `name`, weight_string(`name`) from `user` order by `name` asc", + "Query": "select `name`, weight_string(`name`) from `user` order by `user`.`name` asc", "Table": "`user`" }, { @@ -1645,7 +1645,7 @@ }, "FieldQuery": "select id, id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|2) ASC", - "Query": "select id, id, weight_string(id) from `user` order by id asc", + "Query": "select id, id, weight_string(id) from `user` order by `user`.id asc", "ResultColumns": 2, "Table": "`user`" }, @@ -2095,7 +2095,7 @@ }, "FieldQuery": "select col from `user` where 1 != 1 group by col", "OrderBy": "0 ASC", - "Query": "select col from `user` where id between :vtg1 and :vtg2 group by col order by col asc", + "Query": "select col from `user` where id between :vtg1 and :vtg2 group by col order by `user`.col asc", "Table": "`user`" } ] @@ -2126,7 +2126,7 @@ }, "FieldQuery": "select foo, col, weight_string(foo) from `user` where 1 != 1 group by col, foo, weight_string(foo)", "OrderBy": "1 ASC, (0|2) ASC", - "Query": "select foo, col, weight_string(foo) from `user` where id between :vtg1 and :vtg2 group by col, foo, weight_string(foo) order by col asc, foo asc", + "Query": "select foo, col, weight_string(foo) from `user` where id between :vtg1 and :vtg2 group by col, foo, weight_string(foo) order by `user`.col asc, foo asc", "Table": "`user`" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index d4474d975a1..0411c2295b6 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -1000,7 +1000,7 @@ }, "FieldQuery": "select user_id, weight_string(user_id) from music where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select user_id, weight_string(user_id) from music order by user_id asc limit :__upper_limit", + "Query": "select user_id, weight_string(user_id) from music order by music.user_id asc limit :__upper_limit", "ResultColumns": 1, "Table": "music" } @@ -1806,7 +1806,7 @@ }, "FieldQuery": "select user_id, count(id), weight_string(user_id) from music where 1 != 1 group by user_id", "OrderBy": "(0|2) ASC", - "Query": "select user_id, count(id), weight_string(user_id) from music group by user_id having count(user_id) = 1 order by user_id asc limit :__upper_limit", + "Query": "select user_id, count(id), weight_string(user_id) from music group by user_id having count(user_id) = 1 order by music.user_id asc limit :__upper_limit", "ResultColumns": 2, "Table": "music" } @@ -2115,7 +2115,79 @@ { "comment": "select (select col from user limit 1) as a from user join user_extra order by a", "query": "select (select col from user limit 1) as a from user join user_extra order by a", - "plan": "VT12001: unsupported: correlated subquery is only supported for EXISTS" + "plan": { + "QueryType": "SELECT", + "Original": "select (select col from user limit 1) as a from user join user_extra order by a", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(0|1) ASC", + "Inputs": [ + { + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Limit", + "Count": "1", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col from `user` where 1 != 1", + "Query": "select col from `user` limit :__upper_limit", + "Table": "`user`" + } + ] + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select :__sq1 as a, weight_string(:__sq1) from `user` where 1 != 1", + "Query": "select :__sq1 as a, weight_string(:__sq1) from `user`", + "Table": "`user`" + } + ] + } + ] + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra", + "Table": "user_extra" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } }, { "comment": "plan test for a natural character set string", @@ -2270,7 +2342,7 @@ }, "FieldQuery": "select col, `user`.id from `user` where 1 != 1", "OrderBy": "0 ASC", - "Query": "select col, `user`.id from `user` order by col asc", + "Query": "select col, `user`.id from `user` order by `user`.col asc", "Table": "`user`" }, { @@ -2696,7 +2768,7 @@ }, "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select id, weight_string(id) from `user` order by id asc limit :__upper_limit", + "Query": "select id, weight_string(id) from `user` order by `user`.id asc limit :__upper_limit", "Table": "`user`" } ] @@ -2709,8 +2781,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select :__sq1 as `(select id from ``user`` order by id asc limit 1)` from user_extra where 1 != 1", - "Query": "select :__sq1 as `(select id from ``user`` order by id asc limit 1)` from user_extra", + "FieldQuery": "select :__sq1 as `(select id from ``user`` order by ``user``.id asc limit 1)` from user_extra where 1 != 1", + "Query": "select :__sq1 as `(select id from ``user`` order by ``user``.id asc limit 1)` from user_extra", "Table": "user_extra" } ] @@ -3186,7 +3258,7 @@ }, "FieldQuery": "select id, `name`, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|2) ASC", - "Query": "select id, `name`, weight_string(id) from `user` where `name` = 'aa' order by id asc limit :__upper_limit", + "Query": "select id, `name`, weight_string(id) from `user` where `name` = 'aa' order by `user`.id asc limit :__upper_limit", "ResultColumns": 2, "Table": "`user`" } diff --git a/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json b/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json index ee38e7d0538..f6072bcd9a5 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json @@ -556,7 +556,7 @@ "Sharded": true }, "FieldQuery": "select c_balance, c_first, c_middle, c_id from customer1 where 1 != 1", - "Query": "select c_balance, c_first, c_middle, c_id from customer1 where c_w_id = 840 and c_d_id = 1 and c_last = 'test' order by c_first asc", + "Query": "select c_balance, c_first, c_middle, c_id from customer1 where c_w_id = 840 and c_d_id = 1 and c_last = 'test' order by customer1.c_first asc", "Table": "customer1", "Values": [ "840" @@ -608,7 +608,7 @@ "Sharded": true }, "FieldQuery": "select o_id, o_carrier_id, o_entry_d from orders1 where 1 != 1", - "Query": "select o_id, o_carrier_id, o_entry_d from orders1 where o_w_id = 9894 and o_d_id = 3 and o_c_id = 159 order by o_id desc", + "Query": "select o_id, o_carrier_id, o_entry_d from orders1 where o_w_id = 9894 and o_d_id = 3 and o_c_id = 159 order by orders1.o_id desc", "Table": "orders1", "Values": [ "9894" @@ -660,7 +660,7 @@ "Sharded": true }, "FieldQuery": "select no_o_id from new_orders1 where 1 != 1", - "Query": "select no_o_id from new_orders1 where no_d_id = 689 and no_w_id = 15 order by no_o_id asc limit 1 for update", + "Query": "select no_o_id from new_orders1 where no_d_id = 689 and no_w_id = 15 order by new_orders1.no_o_id asc limit 1 for update", "Table": "new_orders1", "Values": [ "15" diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index 2d225808992..609285c4bfe 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -35,7 +35,7 @@ }, "FieldQuery": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, sum(l_quantity) as avg_qty, sum(l_extendedprice) as avg_price, sum(l_discount) as avg_disc, count(*) as count_order, count(l_quantity), count(l_extendedprice), count(l_discount), weight_string(l_returnflag), weight_string(l_linestatus) from lineitem where 1 != 1 group by l_returnflag, l_linestatus, weight_string(l_returnflag), weight_string(l_linestatus)", "OrderBy": "(0|13) ASC, (1|14) ASC", - "Query": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, sum(l_quantity) as avg_qty, sum(l_extendedprice) as avg_price, sum(l_discount) as avg_disc, count(*) as count_order, count(l_quantity), count(l_extendedprice), count(l_discount), weight_string(l_returnflag), weight_string(l_linestatus) from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus, weight_string(l_returnflag), weight_string(l_linestatus) order by l_returnflag asc, l_linestatus asc", + "Query": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, sum(l_quantity) as avg_qty, sum(l_extendedprice) as avg_price, sum(l_discount) as avg_disc, count(*) as count_order, count(l_quantity), count(l_extendedprice), count(l_discount), weight_string(l_returnflag), weight_string(l_linestatus) from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus, weight_string(l_returnflag), weight_string(l_linestatus) order by lineitem.l_returnflag asc, lineitem.l_linestatus asc", "Table": "lineitem" } ] @@ -213,7 +213,7 @@ }, "FieldQuery": "select o_orderpriority, count(*) as order_count, o_orderkey, weight_string(o_orderpriority) from orders where 1 != 1 group by o_orderpriority, o_orderkey, weight_string(o_orderpriority)", "OrderBy": "(0|3) ASC", - "Query": "select o_orderpriority, count(*) as order_count, o_orderkey, weight_string(o_orderpriority) from orders where o_orderdate >= date('1993-07-01') and o_orderdate < date('1993-07-01') + interval '3' month group by o_orderpriority, o_orderkey, weight_string(o_orderpriority) order by o_orderpriority asc", + "Query": "select o_orderpriority, count(*) as order_count, o_orderkey, weight_string(o_orderpriority) from orders where o_orderdate >= date('1993-07-01') and o_orderdate < date('1993-07-01') + interval '3' month group by o_orderpriority, o_orderkey, weight_string(o_orderpriority) order by orders.o_orderpriority asc", "Table": "orders" }, { @@ -631,9 +631,9 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year)", + "FieldQuery": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), shipping.supp_nation, weight_string(shipping.supp_nation), shipping.cust_nation, weight_string(shipping.cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year)", "OrderBy": "(5|6) ASC, (7|8) ASC, (1|4) ASC", - "Query": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year) order by supp_nation asc, cust_nation asc, l_year asc", + "Query": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), shipping.supp_nation, weight_string(shipping.supp_nation), shipping.cust_nation, weight_string(shipping.cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year) order by shipping.supp_nation asc, shipping.cust_nation asc, shipping.l_year asc", "Table": "lineitem" }, { @@ -1518,7 +1518,7 @@ }, "FieldQuery": "select s_suppkey, s_name, s_address, s_phone, total_revenue, weight_string(s_suppkey) from supplier, revenue0 where 1 != 1", "OrderBy": "(0|5) ASC", - "Query": "select s_suppkey, s_name, s_address, s_phone, total_revenue, weight_string(s_suppkey) from supplier, revenue0 where s_suppkey = supplier_no and total_revenue = :__sq1 order by s_suppkey asc", + "Query": "select s_suppkey, s_name, s_address, s_phone, total_revenue, weight_string(s_suppkey) from supplier, revenue0 where s_suppkey = supplier_no and total_revenue = :__sq1 order by supplier.s_suppkey asc", "ResultColumns": 5, "Table": "revenue0, supplier" } diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json b/go/vt/vtgate/planbuilder/testdata/union_cases.json index 0860c5ca2ff..5d90a72a7dc 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json @@ -128,7 +128,7 @@ }, "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|1) DESC", - "Query": "select id, weight_string(id) from `user` order by id desc limit :__upper_limit", + "Query": "select id, weight_string(id) from `user` order by `user`.id desc limit :__upper_limit", "Table": "`user`" } ] @@ -146,7 +146,7 @@ }, "FieldQuery": "select id, weight_string(id) from music where 1 != 1", "OrderBy": "(0|1) DESC", - "Query": "select id, weight_string(id) from music order by id desc limit :__upper_limit", + "Query": "select id, weight_string(id) from music order by music.id desc limit :__upper_limit", "Table": "music" } ] @@ -258,7 +258,7 @@ }, "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select id, weight_string(id) from `user` order by id asc limit :__upper_limit", + "Query": "select id, weight_string(id) from `user` order by `user`.id asc limit :__upper_limit", "Table": "`user`" } ] @@ -276,7 +276,7 @@ }, "FieldQuery": "select id, weight_string(id) from music where 1 != 1", "OrderBy": "(0|1) DESC", - "Query": "select id, weight_string(id) from music order by id desc limit :__upper_limit", + "Query": "select id, weight_string(id) from music order by music.id desc limit :__upper_limit", "Table": "music" } ] @@ -962,7 +962,7 @@ }, "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select id, weight_string(id) from `user` order by id asc limit :__upper_limit", + "Query": "select id, weight_string(id) from `user` order by `user`.id asc limit :__upper_limit", "Table": "`user`" } ] @@ -980,7 +980,7 @@ }, "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|1) DESC", - "Query": "select id, weight_string(id) from `user` order by id desc limit :__upper_limit", + "Query": "select id, weight_string(id) from `user` order by `user`.id desc limit :__upper_limit", "Table": "`user`" } ] diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index e5792ec6e56..d240cf8d140 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -37,6 +37,7 @@ type analyzer struct { sig QuerySignature si SchemaInformation currentDb string + recheck bool err error inProjection int @@ -73,7 +74,8 @@ func (a *analyzer) lateInit() { binder: a.binder, expandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{}, aliasMapCache: map[*sqlparser.Select]map[string]exprContainer{}, - reAnalyze: a.lateAnalyze, + reAnalyze: a.reAnalyze, + tables: a.tables, } } @@ -236,9 +238,12 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return false } - if err := a.rewriter.up(cursor); err != nil { - a.setError(err) - return true + if !a.recheck { + // no need to run the rewriter on rechecking + if err := a.rewriter.up(cursor); err != nil { + a.setError(err) + return true + } } if err := a.scoper.up(cursor); err != nil { @@ -340,6 +345,14 @@ func (a *analyzer) lateAnalyze(statement sqlparser.SQLNode) error { return a.err } +func (a *analyzer) reAnalyze(statement sqlparser.SQLNode) error { + a.recheck = true + defer func() { + a.recheck = false + }() + return a.lateAnalyze(statement) +} + // canShortCut checks if we are dealing with a single unsharded keyspace and no tables that have managed foreign keys // if so, we can stop the analyzer early func (a *analyzer) canShortCut(statement sqlparser.Statement) bool { @@ -595,6 +608,10 @@ type ShardedError struct { Inner error } +func (p ShardedError) Unwrap() error { + return p.Inner +} + func (p ShardedError) Error() string { return p.Inner.Error() } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index 54ebec20a84..626ffe3f469 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -676,7 +676,10 @@ func TestGroupByBinding(t *testing.T) { TS1, }, { "select a.id from t as a, t1 group by id", - TS0, + // since we have authoritative info on t1, we know that it does have an `id` column, + // and we are missing column info for `t`, we just assume this is coming from t1. + // we really need schema tracking here + TS1, }, { "select a.id from t, t1 as a group by id", TS1, @@ -694,44 +697,47 @@ func TestGroupByBinding(t *testing.T) { func TestHavingBinding(t *testing.T) { tcases := []struct { - sql string - deps TableSet + sql, err string + deps TableSet }{{ - "select col from tabl having col = 1", - TS0, + sql: "select col from tabl having col = 1", + deps: TS0, }, { - "select col from tabl having tabl.col = 1", - TS0, + sql: "select col from tabl having tabl.col = 1", + deps: TS0, }, { - "select col from tabl having d.tabl.col = 1", - TS0, + sql: "select col from tabl having d.tabl.col = 1", + deps: TS0, }, { - "select tabl.col as x from tabl having x = 1", - TS0, + sql: "select tabl.col as x from tabl having col = 1", + deps: TS0, }, { - "select tabl.col as x from tabl having col", - TS0, + sql: "select tabl.col as x from tabl having x = 1", + deps: TS0, }, { - "select col from tabl having 1 = 1", - NoTables, + sql: "select tabl.col as x from tabl having col", + deps: TS0, }, { - "select col as c from tabl having c = 1", - TS0, + sql: "select col from tabl having 1 = 1", + deps: NoTables, }, { - "select 1 as c from tabl having c = 1", - NoTables, + sql: "select col as c from tabl having c = 1", + deps: TS0, }, { - "select t1.id from t1, t2 having id = 1", - TS0, + sql: "select 1 as c from tabl having c = 1", + deps: NoTables, }, { - "select t.id from t, t1 having id = 1", - TS0, + sql: "select t1.id from t1, t2 having id = 1", + deps: TS0, }, { - "select t.id, count(*) as a from t, t1 group by t.id having a = 1", - MergeTableSets(TS0, TS1), + sql: "select t.id from t, t1 having id = 1", + deps: TS0, }, { - "select t.id, sum(t2.name) as a from t, t2 group by t.id having a = 1", - TS1, + sql: "select t.id, count(*) as a from t, t1 group by t.id having a = 1", + deps: MergeTableSets(TS0, TS1), + }, { + sql: "select t.id, sum(t2.name) as a from t, t2 group by t.id having a = 1", + deps: TS1, }, { sql: "select u2.a, u1.a from u1, u2 having u2.a = 2", deps: TS1, diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index d5118371f60..ae95ebbf2f0 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -19,6 +19,8 @@ package semantics import ( "strings" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/sqlparser" ) @@ -197,28 +199,32 @@ func (b *binder) setSubQueryDependencies(subq *sqlparser.Subquery, currScope *sc } func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allowMulti, singleTableFallBack bool) (dependency, error) { + if !current.stmtScope && current.inGroupBy { + return b.resolveColInGroupBy(colName, current, allowMulti, singleTableFallBack) + } + if !current.stmtScope && current.inHaving && !current.inHavingAggr { + return b.resolveColumnInHaving(colName, current, allowMulti) + } + var thisDeps dependencies first := true var tableName *sqlparser.TableName + for current != nil { var err error thisDeps, err = b.resolveColumnInScope(current, colName, allowMulti) if err != nil { - err = makeAmbiguousError(colName, err) - if thisDeps == nil { - return dependency{}, err - } + return dependency{}, makeAmbiguousError(colName, err) } if !thisDeps.empty() { - deps, thisErr := thisDeps.get() - if thisErr != nil { - err = makeAmbiguousError(colName, thisErr) - } - return deps, err - } else if err != nil { - return dependency{}, err + deps, err := thisDeps.get() + return deps, makeAmbiguousError(colName, err) } - if current.parent == nil && len(current.tables) == 1 && first && colName.Qualifier.IsEmpty() && singleTableFallBack { + if current.parent == nil && + len(current.tables) == 1 && + first && + colName.Qualifier.IsEmpty() && + singleTableFallBack { // if this is the top scope, and we still haven't been able to find a match, we know we are about to fail // we can check this last scope and see if there is a single table. if there is just one table in the scope // we assume that the column is meant to come from this table. @@ -236,6 +242,146 @@ func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allow return dependency{}, ShardedError{&ColumnNotFoundError{Column: colName, Table: tableName}} } +func isColumnNotFound(err error) bool { + switch err := err.(type) { + case *ColumnNotFoundError: + return true + case ShardedError: + return isColumnNotFound(err.Inner) + default: + return false + } +} + +func (b *binder) resolveColumnInHaving(colName *sqlparser.ColName, current *scope, allowMulti bool) (dependency, error) { + if current.inHavingAggr { + // when inside an aggregation, we'll search the FROM clause before the SELECT expressions + deps, err := b.resolveColumn(colName, current.parent, allowMulti, true) + if !deps.direct.IsEmpty() || (err != nil && !isColumnNotFound(err)) { + return deps, err + } + } + + // Here we are searching among the SELECT expressions for a match + thisDeps, err := b.resolveColumnInScope(current, colName, allowMulti) + if err != nil { + return dependency{}, makeAmbiguousError(colName, err) + } + + if !thisDeps.empty() { + // we found something! let's return it + deps, err := thisDeps.get() + if err != nil { + err = makeAmbiguousError(colName, err) + } + return deps, err + } + + notFoundErr := &ColumnNotFoundClauseError{Column: colName.Name.String(), Clause: "having clause"} + if current.inHavingAggr { + // if we are inside an aggregation, we've already looked everywhere. now it's time to give up + return dependency{}, notFoundErr + } + + // Now we'll search the FROM clause, but with a twist. If we find it in the FROM clause, the column must also + // exist as a standalone expression in the SELECT list + deps, err := b.resolveColumn(colName, current.parent, allowMulti, true) + if deps.direct.IsEmpty() { + return dependency{}, notFoundErr + } + + sel := current.stmt.(*sqlparser.Select) // we can be sure of this, since HAVING doesn't exist on UNION + if selDeps := b.searchInSelectExpressions(colName, deps, sel); !selDeps.direct.IsEmpty() { + return selDeps, nil + } + + if !current.inHavingAggr && len(sel.GroupBy) == 0 { + // if we are not inside an aggregation, and there is no GROUP BY, we consider the FROM clause before failing + if !deps.direct.IsEmpty() || (err != nil && !isColumnNotFound(err)) { + return deps, err + } + } + + return dependency{}, notFoundErr +} + +// searchInSelectExpressions searches for the ColName among the SELECT and GROUP BY expressions +// It used dependency information to match the columns +func (b *binder) searchInSelectExpressions(colName *sqlparser.ColName, deps dependency, stmt *sqlparser.Select) dependency { + for _, selectExpr := range stmt.SelectExprs { + ae, ok := selectExpr.(*sqlparser.AliasedExpr) + if !ok { + continue + } + selectCol, ok := ae.Expr.(*sqlparser.ColName) + if !ok || !selectCol.Name.Equal(colName.Name) { + continue + } + + _, direct, _ := b.org.depsForExpr(selectCol) + if deps.direct == direct { + // we have found the ColName in the SELECT expressions, so it's safe to use here + direct, recursive, typ := b.org.depsForExpr(ae.Expr) + return dependency{certain: true, direct: direct, recursive: recursive, typ: typ} + } + } + + for _, gb := range stmt.GroupBy { + selectCol, ok := gb.(*sqlparser.ColName) + if !ok || !selectCol.Name.Equal(colName.Name) { + continue + } + + _, direct, _ := b.org.depsForExpr(selectCol) + if deps.direct == direct { + // we have found the ColName in the GROUP BY expressions, so it's safe to use here + direct, recursive, typ := b.org.depsForExpr(gb) + return dependency{certain: true, direct: direct, recursive: recursive, typ: typ} + } + } + return dependency{} +} + +// resolveColInGroupBy handles the special rules we have when binding on the GROUP BY column +func (b *binder) resolveColInGroupBy( + colName *sqlparser.ColName, + current *scope, + allowMulti bool, + singleTableFallBack bool, +) (dependency, error) { + if current.parent == nil { + return dependency{}, vterrors.VT13001("did not expect this to be the last scope") + } + // if we are in GROUP BY, we have to search the FROM clause before we search the SELECT expressions + deps, firstErr := b.resolveColumn(colName, current.parent, allowMulti, false) + if firstErr == nil { + return deps, nil + } + + // either we didn't find the column on a table, or it was ambiguous. + // in either case, next step is to search the SELECT expressions + if !colName.Qualifier.IsEmpty() { + // if the col name has a qualifier, none of the SELECT expressions are going to match + return dependency{}, nil + } + vtbl, ok := current.tables[0].(*vTableInfo) + if !ok { + return dependency{}, vterrors.VT13001("expected the table info to be a *vTableInfo") + } + + dependencies, err := vtbl.dependenciesInGroupBy(colName.Name.String(), b.org) + if err != nil { + return dependency{}, err + } + if dependencies.empty() { + if isColumnNotFound(firstErr) { + return dependency{}, &ColumnNotFoundClauseError{Column: colName.Name.String(), Clause: "group statement"} + } + return deps, firstErr + } + return dependencies.get() +} + func (b *binder) resolveColumnInScope(current *scope, expr *sqlparser.ColName, allowMulti bool) (dependencies, error) { var deps dependencies = ¬hing{} for _, table := range current.tables { diff --git a/go/vt/vtgate/semantics/dependencies.go b/go/vt/vtgate/semantics/dependencies.go index e68c5100ef5..714fa97c2c4 100644 --- a/go/vt/vtgate/semantics/dependencies.go +++ b/go/vt/vtgate/semantics/dependencies.go @@ -32,6 +32,7 @@ type ( merge(other dependencies, allowMulti bool) dependencies } dependency struct { + certain bool direct TableSet recursive TableSet typ evalengine.Type @@ -52,6 +53,7 @@ var ambigousErr = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "ambiguous") func createCertain(direct TableSet, recursive TableSet, qt evalengine.Type) *certain { c := &certain{ dependency: dependency{ + certain: true, direct: direct, recursive: recursive, }, @@ -65,6 +67,7 @@ func createCertain(direct TableSet, recursive TableSet, qt evalengine.Type) *cer func createUncertain(direct TableSet, recursive TableSet) *uncertain { return &uncertain{ dependency: dependency{ + certain: false, direct: direct, recursive: recursive, }, @@ -131,7 +134,7 @@ func (n *nothing) empty() bool { } func (n *nothing) get() (dependency, error) { - return dependency{}, nil + return dependency{certain: true}, nil } func (n *nothing) merge(d dependencies, _ bool) dependencies { diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 900f53603c6..6e71d6d81ea 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -34,6 +34,7 @@ type earlyRewriter struct { warning string expandedColumns map[sqlparser.TableName][]*sqlparser.ColName aliasMapCache map[*sqlparser.Select]map[string]exprContainer + tables *tableCollector // reAnalyze is used when we are running in the late stage, after the other parts of semantic analysis // have happened, and we are introducing or changing the AST. We invoke it so all parts of the query have been @@ -43,8 +44,6 @@ type earlyRewriter struct { func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { - case *sqlparser.Where: - return r.handleWhereClause(node, cursor.Parent()) case sqlparser.SelectExprs: return r.handleSelectExprs(cursor, node) case *sqlparser.JoinTableExpr: @@ -55,13 +54,6 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { rewriteAndExpr(cursor, node) case *sqlparser.NotExpr: rewriteNotExpr(cursor, node) - case sqlparser.GroupBy: - r.clause = "group clause" - iter := &exprIterator{ - node: node, - idx: -1, - } - return r.handleGroupBy(cursor.Parent(), iter) case *sqlparser.ComparisonExpr: return handleComparisonExpr(cursor, node) } @@ -74,6 +66,13 @@ func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case *sqlparser.JoinTableExpr: return r.handleJoinTableExprUp(node) + case sqlparser.GroupBy: + r.clause = "group clause" + iter := &exprIterator{ + node: node, + idx: -1, + } + return r.handleGroupBy(cursor.Parent(), iter) case sqlparser.OrderBy: r.clause = "order clause" iter := &orderByIterator{ @@ -82,6 +81,10 @@ func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { r: r, } return r.handleOrderBy(cursor.Parent(), iter) + case *sqlparser.Where: + if node.Type == sqlparser.HavingClause { + return r.handleHavingClause(node, cursor.Parent()) + } } return nil @@ -114,22 +117,18 @@ func rewriteNotExpr(cursor *sqlparser.Cursor, node *sqlparser.NotExpr) { cursor.Replace(cmp) } -// handleWhereClause processes WHERE clauses, specifically the HAVING clause. -func (r *earlyRewriter) handleWhereClause(node *sqlparser.Where, parent sqlparser.SQLNode) error { +// handleHavingClause processes the HAVING clause +func (r *earlyRewriter) handleHavingClause(node *sqlparser.Where, parent sqlparser.SQLNode) error { sel, ok := parent.(*sqlparser.Select) if !ok { return nil } - if node.Type != sqlparser.HavingClause { - return nil - } - expr, err := r.rewriteAliasesInHavingAndGroupBy(node.Expr, sel) + expr, err := r.rewriteAliasesInHaving(node.Expr, sel) if err != nil { return err } - node.Expr = expr - return nil + return r.reAnalyze(expr) } // handleSelectExprs expands * in SELECT expressions. @@ -208,7 +207,7 @@ func (r *earlyRewriter) replaceLiteralsInOrderBy(e sqlparser.Expr, iter iterator return false, nil } - newExpr, recheck, err := r.rewriteOrderByExpr(lit) + newExpr, recheck, err := r.rewriteOrderByLiteral(lit) if err != nil { return false, err } @@ -246,15 +245,15 @@ func (r *earlyRewriter) replaceLiteralsInOrderBy(e sqlparser.Expr, iter iterator return true, nil } -func (r *earlyRewriter) replaceLiteralsInGroupBy(e sqlparser.Expr, iter iterator) (bool, error) { +func (r *earlyRewriter) replaceLiteralsInGroupBy(e sqlparser.Expr) (sqlparser.Expr, error) { lit := getIntLiteral(e) if lit == nil { - return false, nil + return nil, nil } newExpr, err := r.rewriteGroupByExpr(lit) if err != nil { - return false, err + return nil, err } if getIntLiteral(newExpr) == nil { @@ -277,8 +276,7 @@ func (r *earlyRewriter) replaceLiteralsInGroupBy(e sqlparser.Expr, iter iterator newExpr = sqlparser.NewStrLiteral("") } - err = iter.replace(newExpr) - return true, err + return newExpr, nil } func getIntLiteral(e sqlparser.Expr) *sqlparser.Literal { @@ -344,19 +342,23 @@ func (r *earlyRewriter) handleGroupBy(parent sqlparser.SQLNode, iter iterator) e sel := sqlparser.GetFirstSelect(stmt) for e := iter.next(); e != nil; e = iter.next() { - lit, err := r.replaceLiteralsInGroupBy(e, iter) + expr, err := r.replaceLiteralsInGroupBy(e) if err != nil { return err } - if lit { - continue + if expr == nil { + expr, err = r.rewriteAliasesInGroupBy(e, sel) + if err != nil { + return err + } + } - expr, err := r.rewriteAliasesInHavingAndGroupBy(e, sel) + err = iter.replace(expr) if err != nil { return err } - err = iter.replace(expr) - if err != nil { + + if err = r.reAnalyze(expr); err != nil { return err } } @@ -370,35 +372,14 @@ func (r *earlyRewriter) handleGroupBy(parent sqlparser.SQLNode, iter iterator) e // in SELECT points to that expression, not any table column. // - However, if the aliased expression is an aggregation and the column identifier in // the HAVING/ORDER BY clause is inside an aggregation function, the rule does not apply. -func (r *earlyRewriter) rewriteAliasesInHavingAndGroupBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { +func (r *earlyRewriter) rewriteAliasesInGroupBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { type ExprContainer struct { expr sqlparser.Expr ambiguous bool } - aliases := map[string]ExprContainer{} - for _, e := range sel.SelectExprs { - ae, ok := e.(*sqlparser.AliasedExpr) - if !ok { - continue - } - - var alias string - - item := ExprContainer{expr: ae.Expr} - if ae.As.NotEmpty() { - alias = ae.As.Lowered() - } else if col, ok := ae.Expr.(*sqlparser.ColName); ok { - alias = col.Name.Lowered() - } - - if old, alreadyExists := aliases[alias]; alreadyExists && !sqlparser.Equals.Expr(old.expr, item.expr) { - item.ambiguous = true - } - - aliases[alias] = item - } - + currentScope := r.scoper.currentScope() + aliases := r.getAliasMap(sel) insideAggr := false downF := func(node, _ sqlparser.SQLNode) bool { switch node.(type) { @@ -426,43 +407,111 @@ func (r *earlyRewriter) rewriteAliasesInHavingAndGroupBy(node sqlparser.Expr, se break } + isColumnOnTable, sure := r.isColumnOnTable(col, currentScope) + if found && isColumnOnTable { + r.warning = fmt.Sprintf("Column '%s' in group statement is ambiguous", sqlparser.String(col)) + } + + if isColumnOnTable && sure { + break + } + + if !sure { + r.warning = "Missing table info, so not binding to anything on the FROM clause" + } + if item.ambiguous { err = &AmbiguousColumnError{Column: sqlparser.String(col)} + } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { + err = &InvalidUseOfGroupFunction{} + } + if err != nil { cursor.StopTreeWalk() return } - if insideAggr && sqlparser.ContainsAggregation(item.expr) { - // I'm not sure about this, but my experiments point to this being the behaviour mysql has - // mysql> select min(name) as name from user order by min(name); - // 1 row in set (0.00 sec) - // - // mysql> select id % 2, min(name) as name from user group by id % 2 order by min(name); - // 2 rows in set (0.00 sec) - // - // mysql> select id % 2, 'foobar' as name from user group by id % 2 order by min(name); - // 2 rows in set (0.00 sec) - // - // mysql> select id % 2 from user group by id % 2 order by min(min(name)); - // ERROR 1111 (HY000): Invalid use of group function - // - // mysql> select id % 2, min(name) as k from user group by id % 2 order by min(k); - // ERROR 1111 (HY000): Invalid use of group function - // - // mysql> select id % 2, -id as name from user group by id % 2, -id order by min(name); - // 6 rows in set (0.01 sec) - break + cursor.Replace(sqlparser.CloneExpr(item.expr)) + } + }, nil) + + expr = output.(sqlparser.Expr) + return +} + +func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { + currentScope := r.scoper.currentScope() + if currentScope.isUnion { + // It is not safe to rewrite order by clauses in unions. + return node, nil + } + + aliases := r.getAliasMap(sel) + insideAggr := false + dontEnterSubquery := func(node, _ sqlparser.SQLNode) bool { + switch node.(type) { + case *sqlparser.Subquery: + return false + case sqlparser.AggrFunc: + insideAggr = true + } + + return true + } + output := sqlparser.CopyOnRewrite(node, dontEnterSubquery, func(cursor *sqlparser.CopyOnWriteCursor) { + var col *sqlparser.ColName + + switch node := cursor.Node().(type) { + case sqlparser.AggrFunc: + insideAggr = false + return + case *sqlparser.ColName: + col = node + default: + return + } + + if !col.Qualifier.IsEmpty() { + // we are only interested in columns not qualified by table names + return + } + + item, found := aliases[col.Name.Lowered()] + if insideAggr { + // inside aggregations, we want to first look for columns in the FROM clause + isColumnOnTable, sure := r.isColumnOnTable(col, currentScope) + if isColumnOnTable { + if found && sure { + r.warning = fmt.Sprintf("Column '%s' in having clause is ambiguous", sqlparser.String(col)) + } + return } + } else if !found { + // if outside aggregations, we don't care about FROM columns + // if there is no matching alias, there is no rewriting needed + return + } - cursor.Replace(sqlparser.CloneExpr(item.expr)) + // If we get here, it means we have found an alias and want to use it + if item.ambiguous { + err = &AmbiguousColumnError{Column: sqlparser.String(col)} + } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { + err = &InvalidUseOfGroupFunction{} } + if err != nil { + cursor.StopTreeWalk() + return + } + + newColName := sqlparser.CopyOnRewrite(item.expr, nil, r.fillInQualifiers, nil) + + cursor.Replace(newColName) }, nil) expr = output.(sqlparser.Expr) return } -// rewriteAliasesInOrderBy rewrites columns in the ORDER BY and HAVING clauses to use aliases +// rewriteAliasesInOrderBy rewrites columns in the ORDER BY to use aliases // from the SELECT expressions when applicable, following MySQL scoping rules: // - A column identifier without a table qualifier that matches an alias introduced // in SELECT points to that expression, not any table column. @@ -485,8 +534,7 @@ func (r *earlyRewriter) rewriteAliasesInOrderBy(node sqlparser.Expr, sel *sqlpar insideAggr = true } - _, isSubq := node.(*sqlparser.Subquery) - return !isSubq + return true } output := sqlparser.CopyOnRewrite(node, dontEnterSubquery, func(cursor *sqlparser.CopyOnWriteCursor) { var col *sqlparser.ColName @@ -506,18 +554,29 @@ func (r *earlyRewriter) rewriteAliasesInOrderBy(node sqlparser.Expr, sel *sqlpar return } - item, found := aliases[col.Name.Lowered()] + var item exprContainer + var found bool + + item, found = aliases[col.Name.Lowered()] if !found { // if there is no matching alias, there is no rewriting needed return } + isColumnOnTable, sure := r.isColumnOnTable(col, currentScope) + if found && isColumnOnTable && sure { + r.warning = fmt.Sprintf("Column '%s' in order by statement is ambiguous", sqlparser.String(col)) + } topLevel := col == node - if !topLevel && r.isColumnOnTable(col, currentScope) { + if isColumnOnTable && sure && !topLevel { // we only want to replace columns that are not coming from the table return } + if !sure { + r.warning = "Missing table info, so not binding to anything on the FROM clause" + } + if item.ambiguous { err = &AmbiguousColumnError{Column: sqlparser.String(col)} } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { @@ -528,19 +587,42 @@ func (r *earlyRewriter) rewriteAliasesInOrderBy(node sqlparser.Expr, sel *sqlpar return } - cursor.Replace(sqlparser.CloneExpr(item.expr)) + newColName := sqlparser.CopyOnRewrite(item.expr, nil, r.fillInQualifiers, nil) + + cursor.Replace(newColName) }, nil) expr = output.(sqlparser.Expr) return } -func (r *earlyRewriter) isColumnOnTable(col *sqlparser.ColName, currentScope *scope) bool { +// fillInQualifiers adds qualifiers to any columns we have rewritten +func (r *earlyRewriter) fillInQualifiers(cursor *sqlparser.CopyOnWriteCursor) { + col, ok := cursor.Node().(*sqlparser.ColName) + if !ok || !col.Qualifier.IsEmpty() { + return + } + ts, found := r.binder.direct[col] + if !found { + panic("uh oh") + } + tbl := r.tables.Tables[ts.TableOffset()] + tblName, err := tbl.Name() + if err != nil { + panic(err) + } + cursor.Replace(sqlparser.NewColNameWithQualifier(col.Name.String(), tblName)) +} + +func (r *earlyRewriter) isColumnOnTable(col *sqlparser.ColName, currentScope *scope) (isColumn bool, isCertain bool) { if !currentScope.stmtScope && currentScope.parent != nil { currentScope = currentScope.parent } - _, err := r.binder.resolveColumn(col, currentScope, false, false) - return err == nil + deps, err := r.binder.resolveColumn(col, currentScope, false, false) + if err != nil { + return false, true + } + return true, deps.certain } func (r *earlyRewriter) getAliasMap(sel *sqlparser.Select) (aliases map[string]exprContainer) { @@ -579,7 +661,7 @@ type exprContainer struct { ambiguous bool } -func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (expr sqlparser.Expr, needReAnalysis bool, err error) { +func (r *earlyRewriter) rewriteOrderByLiteral(node *sqlparser.Literal) (expr sqlparser.Expr, needReAnalysis bool, err error) { scope, found := r.scoper.specialExprScopes[node] if !found { return node, false, nil diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index 49fa09897cf..9aee897d8d2 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -172,8 +172,8 @@ func TestExpandStar(t *testing.T) { sql: "select * from t1 join t5 using (b) having b = 12", expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b having t1.b = 12", }, { - sql: "select 1 from t1 join t5 using (b) having b = 12", - expSQL: "select 1 from t1 join t5 on t1.b = t5.b having t1.b = 12", + sql: "select 1 from t1 join t5 using (b) where b = 12", + expSQL: "select 1 from t1 join t5 on t1.b = t5.b where t1.b = 12", }, { sql: "select * from (select 12) as t", expSQL: "select t.`12` from (select 12 from dual) as t", @@ -304,6 +304,90 @@ func TestRewriteJoinUsingColumns(t *testing.T) { } +func TestGroupByColumnName(t *testing.T) { + schemaInfo := &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("id"), + Type: sqltypes.Int32, + }, { + Name: sqlparser.NewIdentifierCI("col1"), + Type: sqltypes.Int32, + }}, + ColumnListAuthoritative: true, + }, + "t2": { + Name: sqlparser.NewIdentifierCS("t2"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("id"), + Type: sqltypes.Int32, + }, { + Name: sqlparser.NewIdentifierCI("col2"), + Type: sqltypes.Int32, + }}, + ColumnListAuthoritative: true, + }, + }, + } + cDB := "db" + tcases := []struct { + sql string + expSQL string + expDeps TableSet + expErr string + warning string + }{{ + sql: "select t3.col from t3 group by kj", + expSQL: "select t3.col from t3 group by kj", + expDeps: TS0, + }, { + sql: "select t2.col2 as xyz from t2 group by xyz", + expSQL: "select t2.col2 as xyz from t2 group by t2.col2", + expDeps: TS0, + }, { + sql: "select id from t1 group by unknown", + expErr: "Unknown column 'unknown' in 'group statement'", + }, { + sql: "select t1.c as x, sum(t2.id) as x from t1 join t2 group by x", + expErr: "VT03005: cannot group on 'x'", + }, { + sql: "select t1.col1, sum(t2.id) as col1 from t1 join t2 group by col1", + expSQL: "select t1.col1, sum(t2.id) as col1 from t1 join t2 group by col1", + expDeps: TS0, + warning: "Column 'col1' in group statement is ambiguous", + }, { + sql: "select t2.col2 as id, sum(t2.id) as x from t1 join t2 group by id", + expSQL: "select t2.col2 as id, sum(t2.id) as x from t1 join t2 group by t2.col2", + expDeps: TS1, + }, { + sql: "select sum(t2.col2) as id, sum(t2.id) as x from t1 join t2 group by id", + expErr: "VT03005: cannot group on 'id'", + }, { + sql: "select count(*) as x from t1 group by x", + expErr: "VT03005: cannot group on 'x'", + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.Parse(tcase.sql) + require.NoError(t, err) + selectStatement := ast.(*sqlparser.Select) + st, err := AnalyzeStrict(selectStatement, cDB, schemaInfo) + if tcase.expErr == "" { + require.NoError(t, err) + assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + gb := selectStatement.GroupBy + deps := st.RecursiveDeps(gb[0]) + assert.Equal(t, tcase.expDeps, deps) + assert.Equal(t, tcase.warning, st.Warning) + } else { + require.EqualError(t, err, tcase.expErr) + } + }) + } +} + func TestGroupByLiteral(t *testing.T) { schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{}, @@ -432,33 +516,89 @@ func TestOrderByLiteral(t *testing.T) { } func TestHavingColumnName(t *testing.T) { - schemaInfo := &FakeSI{ - Tables: map[string]*vindexes.Table{}, - } + schemaInfo := getSchemaWithKnownColumns() cDB := "db" tcases := []struct { - sql string - expSQL string - expErr string + sql string + expSQL string + expDeps TableSet + expErr string + warning string }{{ - sql: "select id, sum(foo) as sumOfFoo from t1 having sumOfFoo > 1", - expSQL: "select id, sum(foo) as sumOfFoo from t1 having sum(foo) > 1", + sql: "select id, sum(foo) as sumOfFoo from t1 having sumOfFoo > 1", + expSQL: "select id, sum(foo) as sumOfFoo from t1 having sum(t1.foo) > 1", + expDeps: TS0, + }, { + sql: "select id as X, sum(foo) as X from t1 having X > 1", + expErr: "Column 'X' in field list is ambiguous", + }, { + sql: "select id, sum(t1.foo) as foo from t1 having sum(foo) > 1", + expSQL: "select id, sum(t1.foo) as foo from t1 having sum(foo) > 1", + expDeps: TS0, + warning: "Column 'foo' in having clause is ambiguous", + }, { + sql: "select id, sum(t1.foo) as XYZ from t1 having sum(XYZ) > 1", + expErr: "Invalid use of group function", + }, { + sql: "select foo + 2 as foo from t1 having foo = 42", + expSQL: "select foo + 2 as foo from t1 having t1.foo + 2 = 42", + expDeps: TS0, + }, { + sql: "select count(*), ename from emp group by ename having comm > 1000", + expErr: "Unknown column 'comm' in 'having clause'", + }, { + sql: "select sal, ename from emp having empno > 1000", + expSQL: "select sal, ename from emp having empno > 1000", + expDeps: TS0, + }, { + sql: "select foo, count(*) foo from t1 group by foo having foo > 1000", + expErr: "Column 'foo' in field list is ambiguous", + }, { + sql: "select foo, count(*) foo from t1, emp group by foo having sum(sal) > 1000", + expSQL: "select foo, count(*) as foo from t1, emp group by foo having sum(sal) > 1000", + expDeps: TS1, + warning: "Column 'foo' in group statement is ambiguous", + }, { + sql: "select foo as X, sal as foo from t1, emp having sum(X) > 1000", + expSQL: "select foo as X, sal as foo from t1, emp having sum(t1.foo) > 1000", + expDeps: TS0, + }, { + sql: "select count(*) a from someTable having a = 10", + expSQL: "select count(*) as a from someTable having count(*) = 10", + expDeps: TS0, + }, { + sql: "select count(*) from emp having ename = 10", + expSQL: "select count(*) from emp having ename = 10", + expDeps: TS0, + }, { + sql: "select sum(sal) empno from emp where ename > 0 having empno = 2", + expSQL: "select sum(sal) as empno from emp where ename > 0 having sum(emp.sal) = 2", + expDeps: TS0, }, { - sql: "select id, sum(foo) as foo from t1 having sum(foo) > 1", - expSQL: "select id, sum(foo) as foo from t1 having sum(foo) > 1", + // test with missing schema info + sql: "select foo, count(bar) as x from someTable group by foo having id > avg(baz)", + expErr: "Unknown column 'id' in 'having clause'", }, { - sql: "select foo + 2 as foo from t1 having foo = 42", - expSQL: "select foo + 2 as foo from t1 having foo + 2 = 42", + sql: "select t1.foo as alias, count(bar) as x from t1 group by foo having foo+54 = 56", + expSQL: "select t1.foo as alias, count(bar) as x from t1 group by foo having foo + 54 = 56", + expDeps: TS0, + }, { + sql: "select 1 from t1 group by foo having foo = 1 and count(*) > 1", + expSQL: "select 1 from t1 group by foo having foo = 1 and count(*) > 1", + expDeps: TS0, }} + for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.Parse(tcase.sql) require.NoError(t, err) - selectStatement := ast.(sqlparser.SelectStatement) - _, err = Analyze(selectStatement, cDB, schemaInfo) + selectStatement := ast.(*sqlparser.Select) + semTbl, err := AnalyzeStrict(selectStatement, cDB, schemaInfo) if tcase.expErr == "" { require.NoError(t, err) assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + assert.Equal(t, tcase.expDeps, semTbl.RecursiveDeps(selectStatement.Having.Expr)) + assert.Equal(t, tcase.warning, semTbl.Warning, "warning") } else { require.EqualError(t, err, tcase.expErr) } @@ -466,7 +606,7 @@ func TestHavingColumnName(t *testing.T) { } } -func TestOrderByColumnName(t *testing.T) { +func getSchemaWithKnownColumns() *FakeSI { schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{ "t1": { @@ -484,66 +624,114 @@ func TestOrderByColumnName(t *testing.T) { }}, ColumnListAuthoritative: true, }, + "emp": { + Keyspace: &vindexes.Keyspace{Name: "ks", Sharded: true}, + Name: sqlparser.NewIdentifierCS("emp"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("empno"), + Type: sqltypes.Int64, + }, { + Name: sqlparser.NewIdentifierCI("ename"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewIdentifierCI("sal"), + Type: sqltypes.Int64, + }}, + ColumnListAuthoritative: true, + }, }, } + return schemaInfo +} + +func TestOrderByColumnName(t *testing.T) { + schemaInfo := getSchemaWithKnownColumns() cDB := "db" tcases := []struct { - sql string - expSQL string - expErr string + sql string + expSQL string + expErr string + warning string + deps TableSet }{{ sql: "select id, sum(foo) as sumOfFoo from t1 order by sumOfFoo", - expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(foo) asc", + expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(t1.foo) asc", + deps: TS0, }, { sql: "select id, sum(foo) as sumOfFoo from t1 order by sumOfFoo + 1", - expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(foo) + 1 asc", + expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(t1.foo) + 1 asc", + deps: TS0, }, { sql: "select id, sum(foo) as sumOfFoo from t1 order by abs(sumOfFoo)", - expSQL: "select id, sum(foo) as sumOfFoo from t1 order by abs(sum(foo)) asc", + expSQL: "select id, sum(foo) as sumOfFoo from t1 order by abs(sum(t1.foo)) asc", + deps: TS0, }, { sql: "select id, sum(foo) as sumOfFoo from t1 order by max(sumOfFoo)", expErr: "Invalid use of group function", }, { - sql: "select id, sum(foo) as foo from t1 order by foo + 1", - expSQL: "select id, sum(foo) as foo from t1 order by foo + 1 asc", + sql: "select id, sum(foo) as foo from t1 order by foo + 1", + expSQL: "select id, sum(foo) as foo from t1 order by foo + 1 asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { - sql: "select id, sum(foo) as foo from t1 order by foo", - expSQL: "select id, sum(foo) as foo from t1 order by sum(foo) asc", + sql: "select id, sum(foo) as foo from t1 order by foo", + expSQL: "select id, sum(foo) as foo from t1 order by sum(t1.foo) asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { - sql: "select id, lower(min(foo)) as foo from t1 order by min(foo)", - expSQL: "select id, lower(min(foo)) as foo from t1 order by min(foo) asc", + sql: "select id, lower(min(foo)) as foo from t1 order by min(foo)", + expSQL: "select id, lower(min(foo)) as foo from t1 order by min(foo) asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { - sql: "select id, lower(min(foo)) as foo from t1 order by foo", - expSQL: "select id, lower(min(foo)) as foo from t1 order by lower(min(foo)) asc", + sql: "select id, lower(min(foo)) as foo from t1 order by foo", + expSQL: "select id, lower(min(foo)) as foo from t1 order by lower(min(t1.foo)) asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { - sql: "select id, lower(min(foo)) as foo from t1 order by abs(foo)", - expSQL: "select id, lower(min(foo)) as foo from t1 order by abs(foo) asc", + sql: "select id, lower(min(foo)) as foo from t1 order by abs(foo)", + expSQL: "select id, lower(min(foo)) as foo from t1 order by abs(foo) asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { - sql: "select id, t1.bar as foo from t1 group by id order by min(foo)", - expSQL: "select id, t1.bar as foo from t1 group by id order by min(foo) asc", + sql: "select id, t1.bar as foo from t1 group by id order by min(foo)", + expSQL: "select id, t1.bar as foo from t1 group by id order by min(foo) asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { sql: "select id, bar as id, count(*) from t1 order by id", expErr: "Column 'id' in field list is ambiguous", }, { - sql: "select id, id, count(*) from t1 order by id", - expSQL: "select id, id, count(*) from t1 order by id asc", + sql: "select id, id, count(*) from t1 order by id", + expSQL: "select id, id, count(*) from t1 order by t1.id asc", + deps: TS0, + warning: "Column 'id' in order by statement is ambiguous", }, { - sql: "select id, count(distinct foo) k from t1 group by id order by k", - expSQL: "select id, count(distinct foo) as k from t1 group by id order by count(distinct foo) asc", + sql: "select id, count(distinct foo) k from t1 group by id order by k", + expSQL: "select id, count(distinct foo) as k from t1 group by id order by count(distinct t1.foo) asc", + deps: TS0, + warning: "Column 'id' in group statement is ambiguous", }, { sql: "select user.id as foo from user union select col from user_extra order by foo", expSQL: "select `user`.id as foo from `user` union select col from user_extra order by foo asc", - }, - } + deps: MergeTableSets(TS0, TS1), + }, { + sql: "select foo as X, sal as foo from t1, emp order by sum(X)", + expSQL: "select foo as X, sal as foo from t1, emp order by sum(t1.foo) asc", + deps: TS0, + }} for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.Parse(tcase.sql) require.NoError(t, err) selectStatement := ast.(sqlparser.SelectStatement) - _, err = Analyze(selectStatement, cDB, schemaInfo) + semTable, err := AnalyzeStrict(selectStatement, cDB, schemaInfo) if tcase.expErr == "" { require.NoError(t, err) assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + orderByExpr := selectStatement.GetOrderBy()[0].Expr + assert.Equal(t, tcase.deps, semTable.RecursiveDeps(orderByExpr)) + assert.Equal(t, tcase.warning, semTable.Warning) } else { require.EqualError(t, err, tcase.expErr) } diff --git a/go/vt/vtgate/semantics/errors.go b/go/vt/vtgate/semantics/errors.go index 6a5a7e29082..6e66a806543 100644 --- a/go/vt/vtgate/semantics/errors.go +++ b/go/vt/vtgate/semantics/errors.go @@ -52,6 +52,7 @@ type ( SubqueryColumnCountError struct{ Expected int } ColumnsMissingInSchemaError struct{} InvalidUseOfGroupFunction struct{} + CantGroupOn struct{ Column string } UnsupportedMultiTablesInUpdateError struct { ExprCount int @@ -65,6 +66,10 @@ type ( Column *sqlparser.ColName Table *sqlparser.TableName } + ColumnNotFoundClauseError struct { + Column string + Clause string + } ) func eprintf(e error, format string, args ...any) string { @@ -264,14 +269,40 @@ func (e *ColumnsMissingInSchemaError) ErrorCode() vtrpcpb.Code { } // InvalidUserOfGroupFunction -func (InvalidUseOfGroupFunction) Error() string { +func (*InvalidUseOfGroupFunction) Error() string { return "Invalid use of group function" } -func (InvalidUseOfGroupFunction) ErrorCode() vtrpcpb.Code { +func (*InvalidUseOfGroupFunction) ErrorCode() vtrpcpb.Code { return vtrpcpb.Code_INVALID_ARGUMENT } -func (InvalidUseOfGroupFunction) ErrorState() vterrors.State { +func (*InvalidUseOfGroupFunction) ErrorState() vterrors.State { return vterrors.InvalidGroupFuncUse } + +// CantGroupOn +func (e *CantGroupOn) Error() string { + return vterrors.VT03005(e.Column).Error() +} + +func (*CantGroupOn) ErrorCode() vtrpcpb.Code { + return vtrpcpb.Code_INVALID_ARGUMENT +} + +func (e *CantGroupOn) ErrorState() vterrors.State { + return vterrors.VT03005(e.Column).State +} + +// ColumnNotFoundInGroupByError +func (e *ColumnNotFoundClauseError) Error() string { + return fmt.Sprintf("Unknown column '%s' in '%s'", e.Column, e.Clause) +} + +func (*ColumnNotFoundClauseError) ErrorCode() vtrpcpb.Code { + return vtrpcpb.Code_INVALID_ARGUMENT +} + +func (e *ColumnNotFoundClauseError) ErrorState() vterrors.State { + return vterrors.BadFieldError +} diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index 5d27b31b84e..107966006f4 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -40,12 +40,15 @@ type ( } scope struct { - parent *scope - stmt sqlparser.Statement - tables []TableInfo - isUnion bool - joinUsing map[string]TableSet - stmtScope bool + parent *scope + stmt sqlparser.Statement + tables []TableInfo + isUnion bool + joinUsing map[string]TableSet + stmtScope bool + inGroupBy bool + inHaving bool + inHavingAggr bool } ) @@ -73,11 +76,20 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { return s.addColumnInfoForOrderBy(cursor, node) case sqlparser.GroupBy: return s.addColumnInfoForGroupBy(cursor, node) - case *sqlparser.Where: - if node.Type != sqlparser.HavingClause { + case sqlparser.AggrFunc: + if !s.currentScope().inHaving { break } - return s.createSpecialScopePostProjection(cursor.Parent()) + s.currentScope().inHavingAggr = true + case *sqlparser.Where: + if node.Type == sqlparser.HavingClause { + err := s.createSpecialScopePostProjection(cursor.Parent()) + if err != nil { + return err + } + s.currentScope().inHaving = true + return nil + } } return nil } @@ -87,10 +99,12 @@ func (s *scoper) addColumnInfoForGroupBy(cursor *sqlparser.Cursor, node sqlparse if err != nil { return err } + currentScope := s.currentScope() + currentScope.inGroupBy = true for _, expr := range node { lit := keepIntLiteral(expr) if lit != nil { - s.specialExprScopes[lit] = s.currentScope() + s.specialExprScopes[lit] = currentScope } } return nil @@ -194,6 +208,8 @@ func (s *scoper) up(cursor *sqlparser.Cursor) error { break } s.popScope() + case sqlparser.AggrFunc: + s.currentScope().inHavingAggr = false case sqlparser.TableExpr: if isParentSelect(cursor) { curScope := s.currentScope() diff --git a/go/vt/vtgate/semantics/vtable.go b/go/vt/vtgate/semantics/vtable.go index ce7efe22371..ebb72e87b2a 100644 --- a/go/vt/vtgate/semantics/vtable.go +++ b/go/vt/vtgate/semantics/vtable.go @@ -42,10 +42,25 @@ func (v *vTableInfo) dependencies(colName string, org originable) (dependencies, if name != colName { continue } - directDeps, recursiveDeps, qt := org.depsForExpr(v.cols[i]) + deps = deps.merge(v.createCertainForCol(org, i), false) + } + if deps.empty() && v.hasStar() { + return createUncertain(v.tables, v.tables), nil + } + return deps, nil +} - newDeps := createCertain(directDeps, recursiveDeps, qt) - deps = deps.merge(newDeps, false) +func (v *vTableInfo) dependenciesInGroupBy(colName string, org originable) (dependencies, error) { + // this method is consciously very similar to vTableInfo.dependencies and should remain so + var deps dependencies = ¬hing{} + for i, name := range v.columnNames { + if name != colName { + continue + } + if sqlparser.ContainsAggregation(v.cols[i]) { + return nil, &CantGroupOn{name} + } + deps = deps.merge(v.createCertainForCol(org, i), false) } if deps.empty() && v.hasStar() { return createUncertain(v.tables, v.tables), nil @@ -53,6 +68,12 @@ func (v *vTableInfo) dependencies(colName string, org originable) (dependencies, return deps, nil } +func (v *vTableInfo) createCertainForCol(org originable, i int) *certain { + directDeps, recursiveDeps, qt := org.depsForExpr(v.cols[i]) + newDeps := createCertain(directDeps, recursiveDeps, qt) + return newDeps +} + // IsInfSchema implements the TableInfo interface func (v *vTableInfo) IsInfSchema() bool { return false