Skip to content

Commit

Permalink
expression: support JSON return type in case expression (#8355)
Browse files Browse the repository at this point in the history
  • Loading branch information
eurekaka committed Nov 21, 2018
1 parent 6fb260f commit 6bed56d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 0 deletions.
37 changes: 37 additions & 0 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
case types.ETDuration:
sig = &builtinCaseWhenDurationSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CaseWhenDuration)
case types.ETJson:
sig = &builtinCaseWhenJSONSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CaseWhenJson)
}
return sig, nil
}
Expand Down Expand Up @@ -425,6 +428,40 @@ func (b *builtinCaseWhenDurationSig) evalDuration(row chunk.Row) (ret types.Dura
return ret, true, nil
}

type builtinCaseWhenJSONSig struct {
baseBuiltinFunc
}

func (b *builtinCaseWhenJSONSig) Clone() builtinFunc {
newSig := &builtinCaseWhenJSONSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalJSON evals a builtinCaseWhenJSONSig.
// See https://dev.mysql.com/doc/refman/5.7/en/case.html
func (b *builtinCaseWhenJSONSig) evalJSON(row chunk.Row) (ret json.BinaryJSON, isNull bool, err error) {
var condition int64
args, l := b.getArgs(), len(b.getArgs())
for i := 0; i < l-1; i += 2 {
condition, isNull, err = args[i].EvalInt(b.ctx, row)
if err != nil {
return
}
if isNull || condition == 0 {
continue
}
return args[i+1].EvalJSON(b.ctx, row)
}
// when clause(condition, result) -> args[i], args[i+1]; (i >= 0 && i+1 < l-1)
// else clause -> args[l-1]
// If case clause has else clause, l%2 == 1.
if l%2 == 1 {
return args[l-1].EvalJSON(b.ctx, row)
}
return ret, true, nil
}

type ifFunctionClass struct {
baseFunctionClass
}
Expand Down
2 changes: 2 additions & 0 deletions expression/builtin_control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ func (s *testEvaluatorSuite) TestCaseWhen(c *C) {
{[]interface{}{nil, 1, nil, 2, 3}, 3},
{[]interface{}{false, 1, nil, 2, 3}, 3},
{[]interface{}{nil, 1, false, 2, 3}, 3},
{[]interface{}{1, jsonInt.GetMysqlJSON(), nil}, 3},
{[]interface{}{0, jsonInt.GetMysqlJSON(), nil}, nil},
}
fc := funcs[ast.Case]
for _, t := range tbl {
Expand Down
2 changes: 2 additions & 0 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
case tipb.ScalarFuncSig_CoalesceInt:
f = &builtinCoalesceIntSig{base}

case tipb.ScalarFuncSig_CaseWhenJson:
f = &builtinCaseWhenJSONSig{base}
case tipb.ScalarFuncSig_CaseWhenDecimal:
f = &builtinCaseWhenDecimalSig{base}
case tipb.ScalarFuncSig_CaseWhenDuration:
Expand Down
2 changes: 2 additions & 0 deletions expression/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ func (s *testEvaluatorSuite) kindToFieldType(kind byte) types.FieldType {
ft.Collate = charset.CollationBin
case types.KindMysqlBit:
ft.Tp = mysql.TypeBit
case types.KindMysqlJSON:
ft.Tp = mysql.TypeJSON
}
return ft
}
Expand Down

0 comments on commit 6bed56d

Please sign in to comment.