Skip to content

Commit

Permalink
Merge branch 'release-4.0' into release-4.0-f3c64ceecefd
Browse files Browse the repository at this point in the history
  • Loading branch information
guo-shaoge committed Mar 16, 2021
2 parents ead41d9 + a6d8bcb commit f7f00ed
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
3 changes: 3 additions & 0 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}

fieldTp := types.AggFieldType(fieldTps)
// Here we turn off NotNullFlag. Because if all when-clauses are false,
// the result of case-when expr is NULL.
types.SetTypeFlag(&fieldTp.Flag, mysql.NotNullFlag, false)
tp := fieldTp.EvalType()

if tp == types.ETInt {
Expand Down
10 changes: 10 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2920,6 +2920,16 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) {
result.Check(testkit.Rows("<nil> 4"))
result = tk.MustQuery("select * from t where b = case when a is null then 4 when a = 'str5' then 7 else 9 end")
result.Check(testkit.Rows("<nil> 4"))

// return type of case when expr should not include NotNullFlag. issue-23036
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(c1 int not null)")
tk.MustExec("insert into t1 values(1)")
result = tk.MustQuery("select (case when null then c1 end) is null from t1")
result.Check(testkit.Rows("1"))
result = tk.MustQuery("select (case when null then c1 end) is not null from t1")
result.Check(testkit.Rows("0"))

// test warnings
tk.MustQuery("select case when b=0 then 1 else 1/b end from t")
tk.MustQuery("show warnings").Check(testkit.Rows())
Expand Down
7 changes: 4 additions & 3 deletions types/field_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ func AggregateEvalType(fts []*FieldType, flag *uint) EvalType {
}
lft = rft
}
setTypeFlag(flag, mysql.UnsignedFlag, unsigned)
setTypeFlag(flag, mysql.BinaryFlag, !aggregatedEvalType.IsStringKind() || gotBinString)
SetTypeFlag(flag, mysql.UnsignedFlag, unsigned)
SetTypeFlag(flag, mysql.BinaryFlag, !aggregatedEvalType.IsStringKind() || gotBinString)
return aggregatedEvalType
}

Expand All @@ -159,7 +159,8 @@ func mergeEvalType(lhs, rhs EvalType, lft, rft *FieldType, isLHSUnsigned, isRHSU
return ETInt
}

func setTypeFlag(flag *uint, flagItem uint, on bool) {
// SetTypeFlag turns the flagItem on or off.
func SetTypeFlag(flag *uint, flagItem uint, on bool) {
if on {
*flag |= flagItem
} else {
Expand Down

0 comments on commit f7f00ed

Please sign in to comment.