Skip to content

Commit

Permalink
This is an automated cherry-pick of pingcap#46303
Browse files Browse the repository at this point in the history
Signed-off-by: ti-chi-bot <[email protected]>
  • Loading branch information
AilinKid authored and ti-chi-bot committed Aug 31, 2023
1 parent 4cb399e commit 5416295
Show file tree
Hide file tree
Showing 5 changed files with 359 additions and 1 deletion.
155 changes: 155 additions & 0 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,152 @@ func pushNotAcrossArgs(ctx sessionctx.Context, exprs []Expression, not bool) ([]
return newExprs, flag
}

// noPrecisionLossCastCompatible reports whether a cast from a column of type
// argCol to the type cast passes the no-precision-loss checks implemented here.
// Only two type pairs are accepted: varchar -> varchar and integer -> integer.
// Every other combination (including char, which per the original author carries
// a padding suffix) is rejected.
//
// todo: consider more no precision-loss downcast cases.
func noPrecisionLossCastCompatible(cast, argCol *types.FieldType) bool {
	castTp, colTp := cast.GetType(), argCol.GetType()
	switch {
	case types.IsTypeVarchar(castTp) && types.IsTypeVarchar(colTp):
		// A varchar cast may only widen the length, and the collation of the
		// cast result must be compatible with the collation of the column.
		return cast.GetFlen() >= argCol.GetFlen() &&
			collate.CompatibleCollate(cast.GetCollate(), argCol.GetCollate())
	case mysql.IsIntegerType(castTp) && mysql.IsIntegerType(colTp):
		// The flen stored on an integer type may only be a display length, so
		// compare the default lengths of the two types instead.
		castFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(castTp)
		originFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(colTp)
		// An integer cast may only widen the type and must keep the signedness.
		return castFlen >= originFlen &&
			mysql.HasUnsignedFlag(cast.GetFlag()) == mysql.HasUnsignedFlag(argCol.GetFlag())
	default:
		// now only consider varchar type and integer.
		return false
	}
}

// unwrapCast tries to strip a cast wrapper from one operand of the binary
// function parentF. castOffset (0 or 1) selects the operand expected to be
// `cast(column)`; the other operand must be a constant.
//
// On success it returns a new function with the same name and return type as
// parentF in which the cast has been replaced by the bare column, and true.
// Otherwise it returns parentF unchanged and false.
func unwrapCast(sctx sessionctx.Context, parentF *ScalarFunction, castOffset int) (Expression, bool) {
	_, parentCollation := parentF.CharsetAndCollation()
	args := parentF.GetArgs()

	castFunc, isFunc := args[castOffset].(*ScalarFunction)
	if !isFunc || castFunc.FuncName.L != ast.Cast {
		return parentF, false
	}
	// eg: for (cast(A) EQ const) with incompatible collation, the condition still
	// can not be used to build range even after the cast is eliminated.
	if castFunc.RetType.EvalType() == types.ETString &&
		!collate.CompatibleCollate(castFunc.RetType.GetCollate(), parentCollation) {
		return parentF, false
	}
	// The operand on the other side must be a constant.
	if _, isConst := args[1-castOffset].(*Constant); !isConst {
		return parentF, false
	}
	// The direct argument of the cast must be a column.
	col, isCol := castFunc.GetArgs()[0].(*Column)
	if !isCol {
		return parentF, false
	}
	// Currently only varchar and integer casts are considered.
	if !noPrecisionLossCastCompatible(castFunc.RetType, col.RetType) {
		return parentF, false
	}

	// Rebuild the function, putting the bare column where the cast used to be.
	left, right := args[0], args[1]
	if castOffset == 0 {
		left = col
	} else {
		right = col
	}
	return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, left, right), true
}

// eliminateCastFunction removes a cast wrapper around a column when the cast is
// judged to lose no precision, so that the rewritten condition can be used to
// build a range or point. For string types the collation is checked as well.
//
// Handled shapes:
//   - OR / AND: every flattened item is processed recursively, and the expression
//     is recomposed only if at least one item changed;
//   - =, <=>, <=, >=, <, >: `cast(col) op const` or `const op cast(col)`;
//   - IN: `cast(col) in (const, ...)`.
//
// It returns the (possibly rewritten) expression and whether anything changed.
func eliminateCastFunction(sctx sessionctx.Context, expr Expression) (_ Expression, changed bool) {
	f, isFunc := expr.(*ScalarFunction)
	if !isFunc {
		return expr, false
	}
	_, collation := expr.CharsetAndCollation()
	switch f.FuncName.L {
	case ast.LogicOr:
		items := FlattenDNFConditions(f)
		rewritten := make([]Expression, 0, len(items))
		anyRewritten := false
		for _, item := range items {
			newItem, itemRewritten := eliminateCastFunction(sctx, item)
			rewritten = append(rewritten, newItem)
			anyRewritten = anyRewritten || itemRewritten
		}
		if !anyRewritten {
			return expr, false
		}
		// compose the new DNF expression.
		return ComposeDNFCondition(sctx, rewritten...), true
	case ast.LogicAnd:
		items := FlattenCNFConditions(f)
		rewritten := make([]Expression, 0, len(items))
		anyRewritten := false
		for _, item := range items {
			newItem, itemRewritten := eliminateCastFunction(sctx, item)
			rewritten = append(rewritten, newItem)
			anyRewritten = anyRewritten || itemRewritten
		}
		if !anyRewritten {
			return expr, false
		}
		// compose the new CNF expression.
		return ComposeCNFCondition(sctx, rewritten...), true
	case ast.EQ, ast.NullEQ, ast.LE, ast.GE, ast.LT, ast.GT:
		// Try the cast on the left first, e.g. eq(cast(test.t2.a, varchar(100)), "aaaaa"),
		// then on the right, e.g. eq("aaaaa", cast(test.t2.a, varchar(100))).
		for offset := 0; offset < 2; offset++ {
			if unwrapped, done := unwrapCast(sctx, f, offset); done {
				return unwrapped, true
			}
		}
	case ast.In:
		// case for: cast(a<int> as bigint) in (1,2,3), column 'a' can be used directly.
		args := f.GetArgs()
		castFunc, isCast := args[0].(*ScalarFunction)
		if !isCast || castFunc.FuncName.L != ast.Cast {
			return expr, false
		}
		// eg: for (cast(A) IN {const}) with incompatible collation, the condition still
		// can not be used to build range even after the cast is eliminated.
		if castFunc.RetType.EvalType() == types.ETString &&
			!collate.CompatibleCollate(castFunc.RetType.GetCollate(), collation) {
			return expr, false
		}
		// Every element of the IN list must be a constant.
		for _, candidate := range args[1:] {
			if _, isConst := candidate.(*Constant); !isConst {
				return expr, false
			}
		}
		// The direct argument of the cast must be a column.
		col, isCol := castFunc.GetArgs()[0].(*Column)
		if !isCol {
			return expr, false
		}
		// Currently only varchar and integer casts are considered.
		if !noPrecisionLossCastCompatible(castFunc.RetType, col.RetType) {
			return expr, false
		}
		newArgs := make([]Expression, 0, len(args))
		newArgs = append(newArgs, col)
		newArgs = append(newArgs, args[1:]...)
		return NewFunctionInternal(sctx, f.FuncName.L, f.RetType, newArgs...), true
	}
	return expr, false
}

// pushNotAcrossExpr try to eliminate the NOT expr in expression tree.
// Input `not` indicates whether there's a `NOT` be pushed down.
// Output `changed` indicates whether the output expression differs from the
Expand Down Expand Up @@ -781,6 +927,15 @@ func PushDownNot(ctx sessionctx.Context, expr Expression) Expression {
return newExpr
}

// EliminateNoPrecisionLossCast removes redundant cast functions from expr to make
// range building easier, and returns the resulting expression.
//  1. A cast nested inside other, more complicated functions is not considered.
//  2. The function holding the cast must have the cast of a base column on one
//     side and constant(s) on the other.
//  3. Collation compatibility and precision loss are checked before a cast is removed.
func EliminateNoPrecisionLossCast(sctx sessionctx.Context, expr Expression) Expression {
	rewritten, _ := eliminateCastFunction(sctx, expr)
	return rewritten
}

// ContainOuterNot checks if there is an outer `not`.
func ContainOuterNot(expr Expression) bool {
return containOuterNot(expr, false)
Expand Down
49 changes: 49 additions & 0 deletions planner/core/casetest/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3310,6 +3310,55 @@ func TestTiFlashFineGrainedShuffle(t *testing.T) {
}
}

// TestDowncastPointGetOrRangeScan checks, through UNION views over tables with
// different key types, which `cast(col) <op> const` conditions are rewritten so
// that a point-get or range scan can be used, and which are left untouched.
func TestDowncastPointGetOrRangeScan(t *testing.T) {
	store := testkit.CreateMockStore(t)
	tk := testkit.NewTestKit(t, store)
	tk.MustExec("use test")

	schemaStmts := []string{
		// `select * from v1 where a = 1` leads to the condition EQ(cast(t2.a as bigint), 1).
		// It should be downcast to t2.a = 1 so that the pk point-get can be used,
		// because the cast does not lose any precision.
		"create table t1 (a bigint key)",
		"create table t2 (a int key)",
		"create definer=`root`@`127.0.0.1` view v1 as (select a from t1) union (select a from t2)",

		// `select * from v2 where a = 'test'` leads to the condition
		// EQ(cast(t4.a as varchar(100) same collation), 'test'). It should be downcast
		// to t4.a = 'test' so that the pk point-get can be used, because the cast
		// does not lose any precision.
		"create table t3 (a varchar(100) key)",
		"create table t4 (a varchar(10) key)",
		"create definer=`root`@`127.0.0.1` view v2 as (select a from t3) union (select a from t4)",

		// `select * from v3 where a = 'test'` leads to the condition
		// EQ(cast(t6.a as char(100) same collation), 'test'). For the char type it
		// depends: with a binary collation, the padding appended when casting column a
		// from char(10) to char(100) makes the comparison a = 'test' differ before and
		// after the UNION operator, so this kind of downcast is not allowed currently.
		"create table t5 (a char(100) key)",
		"create table t6 (a char(10) key)",
		"create definer=`root`@`127.0.0.1` view v3 as (select a from t5) union (select a from t6)",

		// UNION unifies a(int) and a(varchar(100)) as varchar(100), so
		// `select * from v4 where a = "test"` leads to the condition
		// EQ(cast(t8.a as varchar(100)), "test"). Casting int to varchar(100) may lose
		// precision, so a = "test" cannot be used to get the range directly.
		"create table t7 (a varchar(100) key)",
		"create table t8 (a int key)",
		"create definer=`root`@`127.0.0.1` view v4 as (select a from t7) union (select a from t8)",
	}
	for _, stmt := range schemaStmts {
		tk.MustExec(stmt)
	}

	var input []string
	var output []struct {
		SQL    string
		Plan   []string
		Result []string
	}
	suite := GetIntegrationSuiteData()
	suite.LoadTestCases(t, &input, &output)
	for idx, sql := range input {
		explainSQL := "explain format='brief' " + sql
		// The callback fills in the expected output for this case.
		testdata.OnRecord(func() {
			output[idx].SQL = sql
			output[idx].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(explainSQL).Rows())
			output[idx].Result = testdata.ConvertRowsToStrings(tk.MustQuery(sql).Sort().Rows())
		})
		tk.MustQuery(explainSQL).Check(testkit.Rows(output[idx].Plan...))
		tk.MustQuery(sql).Sort().Check(testkit.Rows(output[idx].Result...))
	}
}

func TestNullConditionForPrefixIndex(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
15 changes: 15 additions & 0 deletions planner/core/casetest/testdata/integration_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -1321,11 +1321,26 @@
]
},
{
"name": "TestFixControl",
"cases": [
"set @@tidb_opt_fix_control = \"1000:'on', 10000:1\"",
"set @@tidb_opt_fix_control = \"100:'on', 100:1\"",
"set @@tidb_opt_fix_control = \"100, 100\""
]
},
{
"name": "TestDowncastPointGetOrRangeScan",
"cases": [
"select * from v1 where a = 1; -- the condition should be downcast through both side and go get point",
"select * from v1 where a = '1test'; -- the condition should be downcast through both side and go get point too",
"select * from v1 where a > 1; -- the condition should be downcast through both side and go range scan",
"select * from v2 where a = 'test';",
"select * from v2 where a = 1;",
"select * from v2 where a > 'test';",
"select * from v3 where a = 'test' -- the condition shouldn't be downcast through both side and go get point",
"select * from v3 where a > 'test' -- the condition shouldn't be downcast through both side and go get point too",
"select * from v4 where a = 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range",
"select * from v4 where a > 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range"
]
}
]
132 changes: 132 additions & 0 deletions planner/core/casetest/testdata/integration_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -11180,6 +11180,7 @@
]
},
{
<<<<<<< HEAD
"Name": "TestFixControl",
"Cases": [
{
Expand Down Expand Up @@ -11227,6 +11228,137 @@
"Variable": [
"100:'on', 100:1"
]
=======
"Name": "TestDowncastPointGetOrRangeScan",
"Cases": [
{
"SQL": "select * from v1 where a = 1; -- the condition should be downcast through both side and go get point",
"Plan": [
"HashAgg 2.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 2.00 root ",
" ├─Point_Get 1.00 root table:t1 handle:1",
" └─Projection 1.00 root cast(test.t2.a, bigint(20) BINARY)->Column#3",
" └─Point_Get 1.00 root table:t2 handle:1"
],
"Result": null
},
{
"SQL": "select * from v1 where a = '1test'; -- the condition should be downcast through both side and go get point too",
"Plan": [
"HashAgg 2.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 2.00 root ",
" ├─Point_Get 1.00 root table:t1 handle:1",
" └─Projection 1.00 root cast(test.t2.a, bigint(20) BINARY)->Column#3",
" └─Point_Get 1.00 root table:t2 handle:1"
],
"Result": null
},
{
"SQL": "select * from v1 where a > 1; -- the condition should be downcast through both side and go range scan",
"Plan": [
"HashAgg 5333.33 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 6666.67 root ",
" ├─TableReader 3333.33 root data:TableRangeScan",
" │ └─TableRangeScan 3333.33 cop[tikv] table:t1 range:(1,+inf], keep order:false, stats:pseudo",
" └─Projection 3333.33 root cast(test.t2.a, bigint(20) BINARY)->Column#3",
" └─TableReader 3333.33 root data:TableRangeScan",
" └─TableRangeScan 3333.33 cop[tikv] table:t2 range:(1,+inf], keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v2 where a = 'test';",
"Plan": [
"HashAgg 16.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 20.00 root ",
" ├─Point_Get 1.00 root table:t3, clustered index:PRIMARY(a) ",
" └─Projection 10.00 root cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─Point_Get 1.00 root table:t4, clustered index:PRIMARY(a) "
],
"Result": null
},
{
"SQL": "select * from v2 where a = 1;",
"Plan": [
"HashAgg 12800.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 16000.00 root ",
" ├─TableReader 8000.00 root data:Selection",
" │ └─Selection 8000.00 cop[tikv] eq(cast(test.t3.a, double BINARY), 1)",
" │ └─TableFullScan 10000.00 cop[tikv] table:t3 keep order:false, stats:pseudo",
" └─Projection 8000.00 root cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 8000.00 root data:Selection",
" └─Selection 8000.00 cop[tikv] eq(cast(cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), double BINARY), 1)",
" └─TableFullScan 10000.00 cop[tikv] table:t4 keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v2 where a > 'test';",
"Plan": [
"HashAgg 5333.33 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 6666.67 root ",
" ├─TableReader 3333.33 root data:TableRangeScan",
" │ └─TableRangeScan 3333.33 cop[tikv] table:t3 range:(\"test\",+inf], keep order:false, stats:pseudo",
" └─Projection 3333.33 root cast(test.t4.a, varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 3333.33 root data:TableRangeScan",
" └─TableRangeScan 3333.33 cop[tikv] table:t4 range:(\"test\",+inf], keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v3 where a = 'test' -- the condition shouldn't be downcast through both side and go get point",
"Plan": [
"HashAgg 6408.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 8010.00 root ",
" ├─Point_Get 1.00 root table:t5, clustered index:PRIMARY(a) ",
" └─Projection 8000.00 root cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 8000.00 root data:Selection",
" └─Selection 8000.00 cop[tikv] eq(cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")",
" └─TableFullScan 10000.00 cop[tikv] table:t6 keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v3 where a > 'test' -- the condition shouldn't be downcast through both side and go get point too",
"Plan": [
"HashAgg 9066.67 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 11333.33 root ",
" ├─TableReader 3333.33 root data:TableRangeScan",
" │ └─TableRangeScan 3333.33 cop[tikv] table:t5 range:(\"test\",+inf], keep order:false, stats:pseudo",
" └─Projection 8000.00 root cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 8000.00 root data:Selection",
" └─Selection 8000.00 cop[tikv] gt(cast(test.t6.a, char(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")",
" └─TableFullScan 10000.00 cop[tikv] table:t6 keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v4 where a = 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range",
"Plan": [
"HashAgg 6408.00 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 8010.00 root ",
" ├─Point_Get 1.00 root table:t7, clustered index:PRIMARY(a) ",
" └─Projection 8000.00 root cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 8000.00 root data:Selection",
" └─Selection 8000.00 cop[tikv] eq(cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")",
" └─TableFullScan 10000.00 cop[tikv] table:t8 keep order:false, stats:pseudo"
],
"Result": null
},
{
"SQL": "select * from v4 where a > 'test' -- diff column union may have precision loss couldn't downcast the condition to get the range",
"Plan": [
"HashAgg 9066.67 root group by:Column#3, funcs:firstrow(Column#3)->Column#3",
"└─Union 11333.33 root ",
" ├─TableReader 3333.33 root data:TableRangeScan",
" │ └─TableRangeScan 3333.33 cop[tikv] table:t7 range:(\"test\",+inf], keep order:false, stats:pseudo",
" └─Projection 8000.00 root cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin)->Column#3",
" └─TableReader 8000.00 root data:Selection",
" └─Selection 8000.00 cop[tikv] gt(cast(test.t8.a, varchar(100) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin), \"test\")",
" └─TableFullScan 10000.00 cop[tikv] table:t8 keep order:false, stats:pseudo"
],
"Result": null
>>>>>>> 28a9c7f0fb7 (planner: fix cast(col) = range couldn't build range when cast function doesn't contain any precision loss in some cases (#46303))
}
]
}
Expand Down
Loading

0 comments on commit 5416295

Please sign in to comment.