Skip to content

Commit

Permalink
planner: support grouping function/col/expression rewriting and physi…
Browse files Browse the repository at this point in the history
…cal plan exhaustion for rollup expand OP (#44488)

close #44487
  • Loading branch information
AilinKid authored Jun 12, 2023
1 parent c18e60f commit 465bd60
Show file tree
Hide file tree
Showing 23 changed files with 560 additions and 41 deletions.
2 changes: 2 additions & 0 deletions errno/errcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,8 @@ const (
ErrWindowNoGroupOrderUnused = 3597
ErrWindowExplainJSON = 3598
ErrWindowFunctionIgnoresFrame = 3599
ErrInvalidNumberOfArgs = 3601
ErrFieldInGroupingNotGroupBy = 3602
ErrIllegalPrivilegeLevel = 3619
ErrCTEMaxRecursionDepth = 3636
ErrNotHintUpdatable = 3637
Expand Down
2 changes: 2 additions & 0 deletions errno/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,8 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{
ErrWindowNoGroupOrderUnused: mysql.Message("ASC or DESC with GROUP BY isn't allowed with window functions; put ASC or DESC in ORDER BY", nil),
ErrWindowExplainJSON: mysql.Message("To get information about window functions use EXPLAIN FORMAT=JSON", nil),
ErrWindowFunctionIgnoresFrame: mysql.Message("Window function '%s' ignores the frame clause of window '%s' and aggregates over the whole partition", nil),
ErrInvalidNumberOfArgs: mysql.Message("Too many arguments for function %s; maximum allowed is %d", nil),
ErrFieldInGroupingNotGroupBy: mysql.Message("Argument %s of GROUPING function is not in GROUP BY", nil),
ErrRoleNotGranted: mysql.Message("%s is not granted to %s", nil),
ErrMaxExecTimeExceeded: mysql.Message("Query execution was interrupted, maximum statement execution time exceeded", nil),
ErrLockAcquireFailAndNoWaitSet: mysql.Message("Statement aborted because lock(s) could not be acquired immediately and NOWAIT is set.", nil),
Expand Down
10 changes: 10 additions & 0 deletions errors.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2436,6 +2436,16 @@ error = '''
Window function '%s' ignores the frame clause of window '%s' and aggregates over the whole partition
'''

["planner:3601"]
error = '''
Too many arguments for function %s; maximum allowed is %d
'''

["planner:3602"]
error = '''
Argument %s of GROUPING function is not in GROUP BY
'''

["planner:3637"]
error = '''
Variable '%s' cannot be set using SET_VAR hint.
Expand Down
10 changes: 10 additions & 0 deletions expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ func (col *CorrelatedColumn) VecEvalJSON(ctx sessionctx.Context, input *chunk.Ch
return genVecFromConstExpr(ctx, col, types.ETJson, input, result)
}

// Traverse implements the TraverseDown interface.
func (col *CorrelatedColumn) Traverse(action TraverseAction) Expression {
return action.Transform(col)
}

// Eval implements Expression interface.
func (col *CorrelatedColumn) Eval(row chunk.Row) (types.Datum, error) {
return *col.Data, nil
Expand Down Expand Up @@ -398,6 +403,11 @@ func (col *Column) GetType() *types.FieldType {
return col.RetType
}

// Traverse implements the TraverseDown interface.
func (col *Column) Traverse(action TraverseAction) Expression {
return action.Transform(col)
}

// Eval implements Expression interface.
func (col *Column) Eval(row chunk.Row) (types.Datum, error) {
return row.GetDatum(col.Index, col.RetType), nil
Expand Down
5 changes: 5 additions & 0 deletions expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ func (c *Constant) getLazyDatum(row chunk.Row) (dt types.Datum, isLazy bool, err
return types.Datum{}, false, nil
}

// Traverse implements the TraverseDown interface.
func (c *Constant) Traverse(action TraverseAction) Expression {
return action.Transform(c)
}

// Eval implements Expression interface.
func (c *Constant) Eval(row chunk.Row) (types.Datum, error) {
if dt, lazy, err := c.getLazyDatum(row); lazy {
Expand Down
9 changes: 9 additions & 0 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ type ReverseExpr interface {
ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error)
}

// TraverseAction define the interface for action when traversing down an expression.
type TraverseAction interface {
Transform(Expression) Expression
}

// Expression represents all scalar expression in SQL.
type Expression interface {
fmt.Stringer
Expand All @@ -105,6 +110,8 @@ type Expression interface {
ReverseExpr
CollationInfo

Traverse(TraverseAction) Expression

// Eval evaluates an expression through a row.
Eval(row chunk.Row) (types.Datum, error)

Expand Down Expand Up @@ -1269,6 +1276,8 @@ func scalarExprSupportedByFlash(function *ScalarFunction) bool {
return true
case ast.IsIPv4, ast.IsIPv6:
return true
case ast.Grouping: // grouping function for grouping sets identification.
return true
}
return false
}
Expand Down
7 changes: 3 additions & 4 deletions expression/grouping_sets.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,12 +610,11 @@ func (gss GroupingSets) DistinctSizeWithThreshold(N int) (int, []uint64, map[int
// for every original column unique, traverse the all grouping set.
for idx, oneOriginSetIDs := range originGroupingIDsSlice {
if oneOriginSetIDs.Has(i) {
// this column is needed in this grouping set.
continue
// this column is needed in this grouping set. maintaining the map.
collectionMap[gids[idx]] = struct{}{}
}
// this column is not needed in this grouping set.(this column is grouped)
collectionMap[gids[idx]] = struct{}{}
}
// id2GIDs maintained the needed-column's grouping sets (GIDs)
id2GIDs[i] = collectionMap
})
return len(distinctGroupingIDsPos), gids, id2GIDs
Expand Down
26 changes: 13 additions & 13 deletions expression/grouping_sets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,25 +414,25 @@ func TestDistinctGroupingSets(t *testing.T) {
// for every col id, mapping them to a slice of gid.
require.Equal(t, len(id2Gids), 3)
// 0 --> when grouping(column#1), the corresponding affected gids should be {0}
// +--- explanation: when grouping(a), a is only grouped when grouping id = 0.
require.Equal(t, len(id2Gids[1]), 1)
_, ok := id2Gids[1][0]
// +--- explanation: when grouping(a), col-a is needed when grouping id = 1,2,3.
require.Equal(t, len(id2Gids[1]), 3)
_, ok := id2Gids[1][1]
require.Equal(t, ok, true)
_, ok = id2Gids[1][2]
require.Equal(t, ok, true)
_, ok = id2Gids[1][3]
require.Equal(t, ok, true)
// 1 --> when grouping(column#2), the corresponding affected gids should be {0,1}
// +--- explanation: when grouping(b), b is only grouped when grouping id = 0 or 1.
// +--- explanation: when grouping(b), col-b is needed when grouping id = 2 or 3.
require.Equal(t, len(id2Gids[2]), 2)
_, ok = id2Gids[2][0]
_, ok = id2Gids[2][2]
require.Equal(t, ok, true)
_, ok = id2Gids[2][1]
_, ok = id2Gids[2][3]
require.Equal(t, ok, true)
// 2 --> when grouping(column#3), the corresponding affected gids should be {0,1,2}
// +--- explanation: when grouping(c), c is only grouped when grouping id = 0 or 1 or 2.
require.Equal(t, len(id2Gids[3]), 3)
_, ok = id2Gids[3][0]
require.Equal(t, ok, true)
_, ok = id2Gids[3][1]
require.Equal(t, ok, true)
_, ok = id2Gids[3][2]
// +--- explanation: when grouping(c), col-c is needed when grouping id = 3.
require.Equal(t, len(id2Gids[3]), 1)
_, ok = id2Gids[3][3]
require.Equal(t, ok, true)
// column d is not in the grouping set columns, so it won't be here.
}
5 changes: 5 additions & 0 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,11 @@ func (sf *ScalarFunction) Decorrelate(schema *Schema) Expression {
return sf
}

// Traverse implements the TraverseDown interface.
func (sf *ScalarFunction) Traverse(action TraverseAction) Expression {
return action.Transform(sf)
}

// Eval implements Expression interface.
func (sf *ScalarFunction) Eval(row chunk.Row) (d types.Datum, err error) {
var (
Expand Down
3 changes: 3 additions & 0 deletions expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,3 +597,6 @@ func (m *MockExpr) SetCharsetAndCollation(chs, coll string) {}
func (m *MockExpr) MemoryUsage() (sum int64) {
return
}
func (m *MockExpr) Traverse(action TraverseAction) Expression {
return action.Transform(m)
}
47 changes: 47 additions & 0 deletions planner/core/casetest/enforce_mpp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,3 +646,50 @@ func TestMPPSharedCTEScan(t *testing.T) {
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}

func TestRollupMPP(t *testing.T) {
store := testkit.CreateMockStore(t, internal.WithMockTiFlash(2))
tk := testkit.NewTestKit(t, store)

tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("drop table if exists s")
tk.MustExec("create table t(a int, b int, c int)")
tk.MustExec("create table s(a int, b int, c int)")
tk.MustExec("alter table t set tiflash replica 1")
tk.MustExec("alter table s set tiflash replica 1")

tb := external.GetTableByName(t, tk, "test", "t")
err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)

tb = external.GetTableByName(t, tk, "test", "s")
err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)

var input []string
var output []struct {
SQL string
Plan []string
Warn []string
}

tk.MustExec("set @@tidb_enforce_mpp='on'")
tk.Session().GetSessionVars().TiFlashFineGrainedShuffleStreamCount = -1

enforceMPPSuiteData := GetEnforceMPPSuiteData()
enforceMPPSuiteData.LoadTestCases(t, &input, &output)
for i, tt := range input {
testdata.OnRecord(func() {
output[i].SQL = tt
})
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}
12 changes: 12 additions & 0 deletions planner/core/casetest/testdata/enforce_mpp_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -198,5 +198,17 @@
// The outer one will fail to use MPP. Since the inner one is references the outer one, the whole SQL cannot MPP.
"explain format = 'brief' with c1 as (select /*+ read_from_storage(tikv[t]) */ * from t), c2 as (select c1.* from c1, c1 c2 where c1.b=c2.c) select * from c2 c1, c2, (with c3 as (select * from c1) select c3.* from c3, c3 c4 where c3.c=c4.b) c3 where c1.a=c2.b and c1.a=c3.a"
]
},
{
"name": "TestRollupMPP",
"cases": [
"explain format = 'brief' select count(1) from t group by a, b with rollup; -- 1. simple agg",
"explain format = 'brief' select sum(c), count(1) from t group by a, b with rollup; -- 2. non-grouping set col c",
"explain format = 'brief' select count(a) from t group by a, b with rollup; -- 3. should keep the original col a",
"explain format = 'brief' select grouping(a) from t group by a, b with rollup; -- 4. contain grouping function ref to grouping set column a",
"explain format = 'brief' select grouping(a,b) from t group by a, b with rollup; -- 5. grouping function contains grouping set column a,c",
"explain format = 'brief' select a, grouping(b,a) from t group by a,b with rollup; -- 6. resolve normal column a to grouping set column a'",
"explain format = 'brief' select a+1, grouping(b) from t group by a+1, b with rollup; -- 7. resolve field list a+1 to grouping set column a+1"
]
}
]
Loading

0 comments on commit 465bd60

Please sign in to comment.