Skip to content

Commit

Permalink
planner/core: add DefaultExpr support for expressionRewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
bb7133 committed Dec 2, 2018
1 parent 36bcf5d commit 2cdd49f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
48 changes: 48 additions & 0 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -169,6 +170,7 @@ type expressionRewriter struct {
insertPlan *Insert
}

// constructBinaryOpFunction converts binary operator functions
// 1. If op are EQ or NE or NullEQ, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2)
// 2. Else constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to
// `IF( a0 NE b0, a0 op b0,
Expand Down Expand Up @@ -798,6 +800,8 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
er.isNullToExpression(v)
case *ast.IsTruthExpr:
er.isTrueToScalarFunc(v)
case *ast.DefaultExpr:
er.evalDefaultExpr(v)
default:
er.err = errors.Errorf("UnknownType: %T", v)
return retNode, false
Expand Down Expand Up @@ -1276,3 +1280,47 @@ func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
}
er.err = ErrUnknownColumn.GenWithStackByArgs(v.String(), clauseMsg[er.b.curClause])
}

func (er *expressionRewriter) evalDefaultExpr(v *ast.DefaultExpr) {
col, err := er.schema.FindColumn(v.Name)
if err != nil {
er.err = errors.Trace(err)
return
}

table, err := er.b.is.TableByName(col.DBName, col.OrigTblName)
if err != nil {
er.err = errors.Trace(err)
return
}

var val *expression.Constant
for _, col := range table.Cols() {
if col.Name.L == v.Name.Name.L {
// if column default value is 'current_timestamp', use NULL to be compatible with MySQL 5.7
if hasCurrentTimestampDefault(col) {
val = expression.Null
} else {
val, er.err = er.b.getDefaultValue(col)
}
break
}
}
if er.err != nil {
return
}
stkLen := len(er.ctxStack)
er.ctxStack = er.ctxStack[:stkLen-1]
er.ctxStack = append(er.ctxStack, val)
}

func hasCurrentTimestampDefault(col *table.Column) bool {
if col.Tp != mysql.TypeTimestamp && col.Tp != mysql.TypeDatetime {
return false
}
x, ok := col.DefaultValue.(string)
if !ok {
return false
}
return strings.ToUpper(x) == strings.ToUpper(ast.CurrentTimestamp)
}
44 changes: 44 additions & 0 deletions planner/core/expression_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/testutil"
)

var _ = Suite(&testExpressionRewriterSuite{})
Expand Down Expand Up @@ -58,3 +59,46 @@ func (s *testExpressionRewriterSuite) TestBinaryOpFunction(c *C) {
tk.MustQuery("SELECT * FROM t WHERE (a,b,c) <= (1,2,3) order by b").Check(testkit.Rows("1 1 <nil>", "1 2 3"))
tk.MustQuery("SELECT * FROM t WHERE (a,b,c) > (1,2,3) order by b").Check(testkit.Rows("1 3 <nil>"))
}

func (s *testExpressionRewriterSuite) TestDefaultFunction(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()
tk.MustExec("use test")
tk.MustExec("drop table if exists t1")
tk.MustExec(`create table t1(
a varchar(10) default 'def',
b varchar(10),
c int default '10',
d double default '3.14',
e datetime default '20180101',
f datetime default current_timestamp);`)
tk.MustExec("insert into t1(a, b, c, d) values ('1', '1', 1, 1)")
tk.MustQuery(`select
default(a) as defa,
default(b) as defb,
default(c) as defc,
default(d) as defd,
default(e) as defe,
default(f) as deff
from t1`).Check(testutil.RowsWithSep("|", "def|<nil>|10|3.14|2018-01-01 00:00:00|<nil>"))

tk.MustExec("create table t2(a varchar(10), b varchar(10))")
tk.MustExec("insert into t2 values ('1', '1')")
_, err = tk.Exec("select default(a) from t1, t2")
c.Assert(err, NotNil)
tk.MustQuery("select default(t1.a) from t1, t2").Check(testkit.Rows("def"))

tk.MustExec("prepare stmt from 'select default(a) from t1';")
tk.MustQuery("execute stmt").Check(testkit.Rows("def"))
tk.MustExec("alter table t1 modify a varchar(10) default 'DEF'")
tk.MustQuery("execute stmt").Check(testkit.Rows("DEF"))

tk.MustExec("update t1 set c = c + default(c)")
tk.MustQuery("select c from t1").Check(testkit.Rows("11"))
}

0 comments on commit 2cdd49f

Please sign in to comment.