Skip to content

Commit

Permalink
expression: fix the return type of coalesce when arg type is DATE (pi…
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhangxian1008 authored Sep 11, 2024
1 parent f9c1dd1 commit 87569e8
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 130 deletions.
72 changes: 10 additions & 62 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,19 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

fieldTps := make([]*types.FieldType, 0, len(args))
flag := uint(0)
for _, arg := range args {
fieldTps = append(fieldTps, arg.GetType())
flag |= arg.GetType().GetFlag() & mysql.NotNullFlag
}

// Use the aggregated field type as retType.
resultFieldType := types.AggFieldType(fieldTps)
var tempType uint
resultEvalType := types.AggregateEvalType(fieldTps, &tempType)
resultFieldType.SetFlag(tempType)
retEvalTp := resultFieldType.EvalType()
resultFieldType, err := InferType4ControlFuncs(ctx, c.funcName, args...)
if err != nil {
return nil, err
}

resultFieldType.AddFlag(flag)

retEvalTp := resultFieldType.EvalType()
fieldEvalTps := make([]types.EvalType, 0, len(args))
for range args {
fieldEvalTps = append(fieldEvalTps, retEvalTp)
Expand All @@ -140,60 +141,7 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

bf.tp.AddFlag(resultFieldType.GetFlag())
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(types.UnspecifiedLength)

// Set retType to BINARY(0) if all arguments are of type NULL.
if resultFieldType.GetType() == mysql.TypeNull {
types.SetBinChsClnFlag(bf.tp)
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(0)
} else {
maxIntLen := 0
maxFlen := 0

// Find the max length of field in `maxFlen`,
// and max integer-part length in `maxIntLen`.
for _, argTp := range fieldTps {
if argTp.GetDecimal() > resultFieldType.GetDecimal() {
resultFieldType.SetDecimalUnderLimit(argTp.GetDecimal())
}
argIntLen := argTp.GetFlen()
if argTp.GetDecimal() > 0 {
argIntLen -= argTp.GetDecimal() + 1
}

// Reduce the sign bit if it is a signed integer/decimal
if !mysql.HasUnsignedFlag(argTp.GetFlag()) {
argIntLen--
}
if argIntLen > maxIntLen {
maxIntLen = argIntLen
}
if argTp.GetFlen() > maxFlen || argTp.GetFlen() == types.UnspecifiedLength {
maxFlen = argTp.GetFlen()
}
}
// For integer, field length = maxIntLen + (1/0 for sign bit)
// For decimal, field length = maxIntLen + maxDecimal + (1/0 for sign bit)
if resultEvalType == types.ETInt || resultEvalType == types.ETDecimal {
resultFieldType.SetFlenUnderLimit(maxIntLen + resultFieldType.GetDecimal())
if resultFieldType.GetDecimal() > 0 {
resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1)
}
if !mysql.HasUnsignedFlag(resultFieldType.GetFlag()) {
resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1)
}
bf.tp = resultFieldType
} else {
bf.tp.SetFlen(maxFlen)
}
// Set the field length to maxFlen for other types.
if bf.tp.GetFlen() > mysql.MaxDecimalWidth {
bf.tp.SetFlen(mysql.MaxDecimalWidth)
}
}
bf.tp = resultFieldType

switch retEvalTp {
case types.ETInt:
Expand Down
9 changes: 9 additions & 0 deletions expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,12 @@ func TestRefineArgsWithCastEnum(t *testing.T) {
require.Equal(t, zeroUintConst, args[0])
require.Equal(t, enumCol, args[1])
}

func TestIssue46475(t *testing.T) {
ctx := createContext(t)
args := []interface{}{nil, dt, nil}

f, err := newFunctionForTest(ctx, ast.Coalesce, primitiveValsToConstants(ctx, args)...)
require.NoError(t, err)
require.Equal(t, f.GetType().GetType(), mysql.TypeDate)
}
Loading

0 comments on commit 87569e8

Please sign in to comment.