diff --git a/go/vt/vttablet/tabletmanager/vreplication/player_plan.go b/go/vt/vttablet/tabletmanager/vreplication/player_plan.go index ab960c90a2b..dc7bebf084e 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/player_plan.go +++ b/go/vt/vttablet/tabletmanager/vreplication/player_plan.go @@ -17,9 +17,6 @@ limitations under the License. package vreplication import ( - "fmt" - "strings" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" @@ -34,386 +31,85 @@ type PlayerPlan struct { } // TablePlan is the execution plan for a table within a player stream. -// There are two incarantions of this per table. The first one is built -// while analyzing the inital stream request. A tentative plan is built -// without knowing the table info. The second incarnation is built when -// we receive the field info for a table. At that time, we copy the -// original TablePlan into a separate map and populate the Fields and -// PKCols members. type TablePlan struct { - Name string - ColExprs []*ColExpr `json:",omitempty"` - OnInsert InsertType `json:",omitempty"` - - Fields []*querypb.Field `json:",omitempty"` - PKCols []*ColExpr `json:",omitempty"` -} - -// ColExpr describes the processing to be performed to -// compute the value of the target table column. -type ColExpr struct { - ColName sqlparser.ColIdent - ColNum int - Operation Operation `json:",omitempty"` - IsGrouped bool `json:",omitempty"` + Name string + PKReferences []string `json:",omitempty"` + Insert *sqlparser.ParsedQuery `json:",omitempty"` + Update *sqlparser.ParsedQuery `json:",omitempty"` + Delete *sqlparser.ParsedQuery `json:",omitempty"` + Fields []*querypb.Field `json:",omitempty"` } -// Operation is the opcode for the ColExpr. -type Operation int - -// The following values are the various ColExpr opcodes. -const ( - OpNone = Operation(iota) - OpCount - OpSum -) - -// InsertType describes the type of insert statement to generate. -type InsertType int - -// The following values are the various insert types. -const ( - InsertNormal = InsertType(iota) - InsertOndup - InsertIgnore -) - -func buildPlayerPlan(filter *binlogdatapb.Filter) (*PlayerPlan, error) { - plan := &PlayerPlan{ - VStreamFilter: &binlogdatapb.Filter{ - Rules: make([]*binlogdatapb.Rule, len(filter.Rules)), - }, - TablePlans: make(map[string]*TablePlan), - } - for i, rule := range filter.Rules { - if strings.HasPrefix(rule.Match, "/") { - plan.VStreamFilter.Rules[i] = rule - continue - } - sendRule, tp, err := buildTablePlan(rule) - if err != nil { - return nil, err +func (tp *TablePlan) generateStatements(rowChange *binlogdatapb.RowChange) ([]string, error) { + // MakeRowTrusted is needed here because Proto3ToResult is not convenient. + var before, after bool + bindvars := make(map[string]*querypb.BindVariable) + if rowChange.Before != nil { + before = true + vals := sqltypes.MakeRowTrusted(tp.Fields, rowChange.Before) + for i, field := range tp.Fields { + bindvars["b_"+field.Name] = sqltypes.ValueBindVariable(vals[i]) } - plan.VStreamFilter.Rules[i] = sendRule - plan.TablePlans[sendRule.Match] = tp - } - return plan, nil -} - -func buildTablePlan(rule *binlogdatapb.Rule) (*binlogdatapb.Rule, *TablePlan, error) { - statement, err := sqlparser.Parse(rule.Filter) - if err != nil { - return nil, nil, err } - sel, ok := statement.(*sqlparser.Select) - if !ok { - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(statement)) - } - if sel.Distinct != "" { - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(sel)) - } - if len(sel.From) > 1 { - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(sel)) - } - node, ok := sel.From[0].(*sqlparser.AliasedTableExpr) - if !ok { - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(sel)) - } - fromTable := sqlparser.GetTableName(node.Expr) - if fromTable.IsEmpty() { - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(sel)) - } - - if _, ok := sel.SelectExprs[0].(*sqlparser.StarExpr); ok { - if len(sel.SelectExprs) != 1 { - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(sel)) - } - sendRule := &binlogdatapb.Rule{ - Match: fromTable.String(), - Filter: rule.Filter, + if rowChange.After != nil { + after = true + vals := sqltypes.MakeRowTrusted(tp.Fields, rowChange.After) + for i, field := range tp.Fields { + bindvars["a_"+field.Name] = sqltypes.ValueBindVariable(vals[i]) } - return sendRule, &TablePlan{Name: rule.Match}, nil - } - - tp := &TablePlan{ - Name: rule.Match, } - sendSelect := &sqlparser.Select{ - From: sel.From, - Where: sel.Where, - } - for _, expr := range sel.SelectExprs { - selExpr, cExpr, err := analyzeExpr(expr) + switch { + case !before && after: + query, err := tp.Insert.GenerateQuery(bindvars, nil) if err != nil { - return nil, nil, err + return nil, err } - if selExpr != nil { - sendSelect.SelectExprs = append(sendSelect.SelectExprs, selExpr) - cExpr.ColNum = len(sendSelect.SelectExprs) - 1 + return []string{query}, nil + case before && !after: + if tp.Delete == nil { + return nil, nil } - tp.ColExprs = append(tp.ColExprs, cExpr) - } - - if sel.GroupBy != nil { - if err := analyzeGroupBy(sel.GroupBy, tp); err != nil { - return nil, nil, err + query, err := tp.Delete.GenerateQuery(bindvars, nil) + if err != nil { + return nil, err } - tp.OnInsert = InsertIgnore - for _, cExpr := range tp.ColExprs { - if !cExpr.IsGrouped { - tp.OnInsert = InsertOndup - break + return []string{query}, nil + case before && after: + if !tp.pkChanged(bindvars) { + query, err := tp.Update.GenerateQuery(bindvars, nil) + if err != nil { + return nil, err } + return []string{query}, nil } - } - sendRule := &binlogdatapb.Rule{ - Match: fromTable.String(), - Filter: sqlparser.String(sendSelect), - } - return sendRule, tp, nil -} -func analyzeExpr(selExpr sqlparser.SelectExpr) (sqlparser.SelectExpr, *ColExpr, error) { - aliased, ok := selExpr.(*sqlparser.AliasedExpr) - if !ok { - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(selExpr)) - } - as := aliased.As - if as.IsEmpty() { - as = sqlparser.NewColIdent(sqlparser.String(aliased.Expr)) - } - switch expr := aliased.Expr.(type) { - case *sqlparser.ColName: - return selExpr, &ColExpr{ColName: as}, nil - case *sqlparser.FuncExpr: - if expr.Distinct || len(expr.Exprs) != 1 { - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) - } - if aliased.As.IsEmpty() { - return nil, nil, fmt.Errorf("expression needs an alias: %v", sqlparser.String(expr)) - } - switch fname := expr.Name.Lowered(); fname { - case "month", "day", "hour": - return selExpr, &ColExpr{ColName: as}, nil - case "count": - if _, ok := expr.Exprs[0].(*sqlparser.StarExpr); !ok { - return nil, nil, fmt.Errorf("only count(*) is supported: %v", sqlparser.String(expr)) - } - return nil, &ColExpr{ColName: as, Operation: OpCount}, nil - case "sum": - aInner, ok := expr.Exprs[0].(*sqlparser.AliasedExpr) - if !ok { - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) + queries := make([]string, 0, 2) + if tp.Delete != nil { + query, err := tp.Delete.GenerateQuery(bindvars, nil) + if err != nil { + return nil, err } - innerCol, ok := aInner.Expr.(*sqlparser.ColName) - if !ok { - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) - } - return &sqlparser.AliasedExpr{Expr: innerCol}, &ColExpr{ColName: as, Operation: OpSum}, nil - default: - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) - } - default: - return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) - } -} - -func analyzeGroupBy(groupBy sqlparser.GroupBy, tp *TablePlan) error { - for _, expr := range groupBy { - colname, ok := expr.(*sqlparser.ColName) - if !ok { - return fmt.Errorf("unexpected: %v", sqlparser.String(expr)) - } - cExpr := tp.FindCol(colname.Name) - if cExpr == nil { - return fmt.Errorf("group by expression does not reference an alias in the select list: %v", sqlparser.String(expr)) + queries = append(queries, query) } - if cExpr.Operation != OpNone { - return fmt.Errorf("group by expression is not allowed to reference an aggregate expression: %v", sqlparser.String(expr)) - } - cExpr.IsGrouped = true - } - return nil -} - -//-------------------------------------------------------------- -// TablePlan support functions. - -// FindCol finds the ColExpr. It returns nil if not found. -func (tp *TablePlan) FindCol(name sqlparser.ColIdent) *ColExpr { - for _, cExpr := range tp.ColExprs { - if cExpr.ColName.Equal(name) { - return cExpr - } - } - return nil -} - -// GenerateStatements must be called only after Fields and PKCols have been populated. -func (tp *TablePlan) GenerateStatements(rowChange *binlogdatapb.RowChange) []string { - // MakeRowTrusted is needed here because Proto3ToResult is not convenient. - var before, after []sqltypes.Value - if rowChange.Before != nil { - before = sqltypes.MakeRowTrusted(tp.Fields, rowChange.Before) - } - if rowChange.After != nil { - after = sqltypes.MakeRowTrusted(tp.Fields, rowChange.After) - } - var query string - switch { - case before == nil && after != nil: - query = tp.generateInsert(after) - case before != nil && after != nil: - pkChanged := false - for _, cExpr := range tp.PKCols { - if !valsEqual(before[cExpr.ColNum], after[cExpr.ColNum]) { - pkChanged = true - break - } - } - if pkChanged { - queries := make([]string, 0, 2) - if query := tp.generateDelete(before); query != "" { - queries = append(queries, query) - } - if query := tp.generateInsert(after); query != "" { - queries = append(queries, query) - } - return queries - } - query = tp.generateUpdate(before, after) - case before != nil && after == nil: - query = tp.generateDelete(before) - case before == nil && after == nil: - // unreachable - } - if query != "" { - return []string{query} - } - return nil -} - -func (tp *TablePlan) generateInsert(after []sqltypes.Value) string { - sql := sqlparser.NewTrackedBuffer(nil) - if tp.OnInsert == InsertIgnore { - sql.Myprintf("insert ignore into %v set ", sqlparser.NewTableIdent(tp.Name)) - } else { - sql.Myprintf("insert into %v set ", sqlparser.NewTableIdent(tp.Name)) - } - tp.generateInsertValues(sql, after) - if tp.OnInsert == InsertOndup { - sql.Myprintf(" on duplicate key update ") - _ = tp.generateUpdateValues(sql, nil, after) - } - return sql.String() -} - -func (tp *TablePlan) generateUpdate(before, after []sqltypes.Value) string { - if tp.OnInsert == InsertIgnore { - return tp.generateInsert(after) - } - sql := sqlparser.NewTrackedBuffer(nil) - sql.Myprintf("update %v set ", sqlparser.NewTableIdent(tp.Name)) - if ok := tp.generateUpdateValues(sql, before, after); !ok { - return "" - } - sql.Myprintf(" where ") - tp.generateWhereValues(sql, before) - return sql.String() -} - -func (tp *TablePlan) generateDelete(before []sqltypes.Value) string { - sql := sqlparser.NewTrackedBuffer(nil) - switch tp.OnInsert { - case InsertOndup: - return tp.generateUpdate(before, nil) - case InsertIgnore: - return "" - default: // insertNormal - sql.Myprintf("delete from %v where ", sqlparser.NewTableIdent(tp.Name)) - tp.generateWhereValues(sql, before) - } - return sql.String() -} - -func (tp *TablePlan) generateInsertValues(sql *sqlparser.TrackedBuffer, after []sqltypes.Value) { - separator := "" - for _, cExpr := range tp.ColExprs { - sql.Myprintf("%s%v=", separator, cExpr.ColName) - separator = ", " - if cExpr.Operation == OpCount { - sql.WriteString("1") - } else { - if cExpr.Operation == OpSum && after[cExpr.ColNum].IsNull() { - sql.WriteString("0") - } else { - encodeValue(sql, after[cExpr.ColNum]) - } + query, err := tp.Insert.GenerateQuery(bindvars, nil) + if err != nil { + return nil, err } + queries = append(queries, query) + return queries, nil } + return nil, nil } -// generateUpdateValues returns true if at least one value was set. Otherwise, it returns false. -func (tp *TablePlan) generateUpdateValues(sql *sqlparser.TrackedBuffer, before, after []sqltypes.Value) bool { - separator := "" - hasSet := false - for _, cExpr := range tp.ColExprs { - if cExpr.IsGrouped { - continue - } - if len(before) != 0 && len(after) != 0 { - if cExpr.Operation == OpCount { - continue - } - if valsEqual(before[cExpr.ColNum], after[cExpr.ColNum]) { - continue - } - } - sql.Myprintf("%s%v=", separator, cExpr.ColName) - separator = ", " - hasSet = true - if cExpr.Operation == OpCount || cExpr.Operation == OpSum { - sql.Myprintf("%v", cExpr.ColName) - } - if len(before) != 0 { - switch cExpr.Operation { - case OpNone: - if len(after) == 0 { - sql.WriteString("NULL") - } - case OpCount: - sql.WriteString("-1") - case OpSum: - if !before[cExpr.ColNum].IsNull() { - sql.WriteString("-") - encodeValue(sql, before[cExpr.ColNum]) - } - } - } - if len(after) != 0 { - switch cExpr.Operation { - case OpNone: - encodeValue(sql, after[cExpr.ColNum]) - case OpCount: - sql.WriteString("+1") - case OpSum: - if !after[cExpr.ColNum].IsNull() { - sql.WriteString("+") - encodeValue(sql, after[cExpr.ColNum]) - } - } +func (tp *TablePlan) pkChanged(bindvars map[string]*querypb.BindVariable) bool { + for _, pkref := range tp.PKReferences { + v1, _ := sqltypes.BindVariableToValue(bindvars["b_"+pkref]) + v2, _ := sqltypes.BindVariableToValue(bindvars["a_"+pkref]) + if !valsEqual(v1, v2) { + return true } } - return hasSet -} - -func (tp *TablePlan) generateWhereValues(sql *sqlparser.TrackedBuffer, before []sqltypes.Value) { - separator := "" - for _, cExpr := range tp.PKCols { - sql.Myprintf("%s%v=", separator, cExpr.ColName) - separator = " and " - encodeValue(sql, before[cExpr.ColNum]) - } + return false } func valsEqual(v1, v2 sqltypes.Value) bool { diff --git a/go/vt/vttablet/tabletmanager/vreplication/player_plan_test.go b/go/vt/vttablet/tabletmanager/vreplication/player_plan_test.go index bf5066002f2..2baa5cc9599 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/player_plan_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/player_plan_test.go @@ -21,13 +21,25 @@ import ( "testing" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" - "vitess.io/vitess/go/vt/sqlparser" ) -func TestPlayerPlan(t *testing.T) { +type TestPlayerPlan struct { + VStreamFilter *binlogdatapb.Filter + TablePlans map[string]*TestTablePlan +} + +type TestTablePlan struct { + Name string + PKReferences []string `json:",omitempty"` + Insert string `json:",omitempty"` + Update string `json:",omitempty"` + Delete string `json:",omitempty"` +} + +func TestBuildPlayerPlan(t *testing.T) { testcases := []struct { input *binlogdatapb.Filter - plan *PlayerPlan + plan *TestPlayerPlan err string }{{ // Regular expression @@ -36,13 +48,13 @@ func TestPlayerPlan(t *testing.T) { Match: "/.*", }}, }, - plan: &PlayerPlan{ + plan: &TestPlayerPlan{ VStreamFilter: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ Match: "/.*", }}, }, - TablePlans: map[string]*TablePlan{}, + TablePlans: map[string]*TestTablePlan{}, }, }, { // '*' expression @@ -52,14 +64,14 @@ func TestPlayerPlan(t *testing.T) { Filter: "select * from t2", }}, }, - plan: &PlayerPlan{ + plan: &TestPlayerPlan{ VStreamFilter: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ Match: "t2", Filter: "select * from t2", }}, }, - TablePlans: map[string]*TablePlan{ + TablePlans: map[string]*TestTablePlan{ "t2": { Name: "t1", }, @@ -73,183 +85,117 @@ func TestPlayerPlan(t *testing.T) { Filter: "select c1, c2 from t2", }}, }, - plan: &PlayerPlan{ + plan: &TestPlayerPlan{ VStreamFilter: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ Match: "t2", Filter: "select c1, c2 from t2", }}, }, - TablePlans: map[string]*TablePlan{ + TablePlans: map[string]*TestTablePlan{ "t2": { - Name: "t1", - ColExprs: []*ColExpr{{ - ColName: sqlparser.NewColIdent("c1"), - ColNum: 0, - }, { - ColName: sqlparser.NewColIdent("c2"), - ColNum: 1, - }}, + Name: "t1", + PKReferences: []string{"c1"}, + Insert: "insert into t1 set c1=:a_c1, c2=:a_c2", + Update: "update t1 set c2=:a_c2 where c1=:b_c1", + Delete: "delete from t1 where c1=:b_c1", }, }, }, }, { - // func expr + // partial group by input: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ Match: "t1", - Filter: "select hour(c1) as hc1, day(c2) as dc2 from t2", + Filter: "select c1, c2, c3 from t2 group by c3, c1", }}, }, - plan: &PlayerPlan{ + plan: &TestPlayerPlan{ VStreamFilter: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ Match: "t2", - Filter: "select hour(c1) as hc1, day(c2) as dc2 from t2", + Filter: "select c1, c2, c3 from t2", }}, }, - TablePlans: map[string]*TablePlan{ + TablePlans: map[string]*TestTablePlan{ "t2": { - Name: "t1", - ColExprs: []*ColExpr{{ - ColName: sqlparser.NewColIdent("hc1"), - ColNum: 0, - }, { - ColName: sqlparser.NewColIdent("dc2"), - ColNum: 1, - }}, + Name: "t1", + PKReferences: []string{"c1"}, + Insert: "insert into t1 set c1=:a_c1, c2=:a_c2, c3=:a_c3 on duplicate key update c2=:a_c2", + Update: "update t1 set c2=:a_c2 where c1=:b_c1", + Delete: "update t1 set c2=null where c1=:b_c1", }, }, }, }, { - // count expr + // full group by input: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ Match: "t1", - Filter: "select hour(c1) as hc1, count(*) as c, day(c2) as dc2 from t2", + Filter: "select c1, c2, c3 from t2 group by c3, c1, c2", }}, }, - plan: &PlayerPlan{ + plan: &TestPlayerPlan{ VStreamFilter: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ Match: "t2", - Filter: "select hour(c1) as hc1, day(c2) as dc2 from t2", + Filter: "select c1, c2, c3 from t2", }}, }, - TablePlans: map[string]*TablePlan{ + TablePlans: map[string]*TestTablePlan{ "t2": { - Name: "t1", - ColExprs: []*ColExpr{{ - ColName: sqlparser.NewColIdent("hc1"), - ColNum: 0, - }, { - ColName: sqlparser.NewColIdent("c"), - Operation: OpCount, - }, { - ColName: sqlparser.NewColIdent("dc2"), - ColNum: 1, - }}, + Name: "t1", + PKReferences: []string{"c1"}, + Insert: "insert ignore into t1 set c1=:a_c1, c2=:a_c2, c3=:a_c3", + Update: "insert ignore into t1 set c1=:a_c1, c2=:a_c2, c3=:a_c3", }, }, }, }, { - // sum expr input: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ Match: "t1", - Filter: "select hour(c1) as hc1, sum(c3) as s, day(c2) as dc2 from t2", + Filter: "select foo(a) as c1, b c2 from t1", }}, }, - plan: &PlayerPlan{ + plan: &TestPlayerPlan{ VStreamFilter: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ - Match: "t2", - Filter: "select hour(c1) as hc1, c3, day(c2) as dc2 from t2", + Match: "t1", + Filter: "select a, b from t1", }}, }, - TablePlans: map[string]*TablePlan{ - "t2": { - Name: "t1", - ColExprs: []*ColExpr{{ - ColName: sqlparser.NewColIdent("hc1"), - ColNum: 0, - }, { - ColName: sqlparser.NewColIdent("s"), - ColNum: 1, - Operation: OpSum, - }, { - ColName: sqlparser.NewColIdent("dc2"), - ColNum: 2, - }}, + TablePlans: map[string]*TestTablePlan{ + "t1": { + Name: "t1", + PKReferences: []string{"a"}, + Insert: "insert into t1 set c1=foo(:a_a), c2=:a_b", + Update: "update t1 set c2=:a_b where c1=(foo(:b_a))", + Delete: "delete from t1 where c1=(foo(:b_a))", }, }, }, }, { - // partial group by input: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ Match: "t1", - Filter: "select c1, c2, c3 from t2 group by c3, c1", + Filter: "select a + b as c1, c as c2 from t1", }}, }, - plan: &PlayerPlan{ + plan: &TestPlayerPlan{ VStreamFilter: &binlogdatapb.Filter{ Rules: []*binlogdatapb.Rule{{ - Match: "t2", - Filter: "select c1, c2, c3 from t2", + Match: "t1", + Filter: "select a, b, c from t1", }}, }, - TablePlans: map[string]*TablePlan{ - "t2": { - Name: "t1", - ColExprs: []*ColExpr{{ - ColName: sqlparser.NewColIdent("c1"), - ColNum: 0, - IsGrouped: true, - }, { - ColName: sqlparser.NewColIdent("c2"), - ColNum: 1, - }, { - ColName: sqlparser.NewColIdent("c3"), - ColNum: 2, - IsGrouped: true, - }}, - OnInsert: InsertOndup, - }, - }, - }, - }, { - // full group by - input: &binlogdatapb.Filter{ - Rules: []*binlogdatapb.Rule{{ - Match: "t1", - Filter: "select c1, c2, c3 from t2 group by c3, c1, c2", - }}, - }, - plan: &PlayerPlan{ - VStreamFilter: &binlogdatapb.Filter{ - Rules: []*binlogdatapb.Rule{{ - Match: "t2", - Filter: "select c1, c2, c3 from t2", - }}, - }, - TablePlans: map[string]*TablePlan{ - "t2": { - Name: "t1", - ColExprs: []*ColExpr{{ - ColName: sqlparser.NewColIdent("c1"), - ColNum: 0, - IsGrouped: true, - }, { - ColName: sqlparser.NewColIdent("c2"), - ColNum: 1, - IsGrouped: true, - }, { - ColName: sqlparser.NewColIdent("c3"), - ColNum: 2, - IsGrouped: true, - }}, - OnInsert: InsertIgnore, + TablePlans: map[string]*TestTablePlan{ + "t1": { + Name: "t1", + PKReferences: []string{"a", "b"}, + Insert: "insert into t1 set c1=:a_a + :a_b, c2=:a_c", + Update: "update t1 set c2=:a_c where c1=(:b_a + :b_b)", + Delete: "delete from t1 where c1=(:b_a + :b_b)", }, }, }, @@ -370,24 +316,6 @@ func TestPlayerPlan(t *testing.T) { }}, }, err: "unexpected: sum(a + b)", - }, { - // unsupported func - input: &binlogdatapb.Filter{ - Rules: []*binlogdatapb.Rule{{ - Match: "t1", - Filter: "select foo(a) as c from t1", - }}, - }, - err: "unexpected: foo(a)", - }, { - // no complex expr in select - input: &binlogdatapb.Filter{ - Rules: []*binlogdatapb.Rule{{ - Match: "t1", - Filter: "select a + b from t1", - }}, - }, - err: "unexpected: a + b", }, { // no complex expr in group by input: &binlogdatapb.Filter{ @@ -417,8 +345,12 @@ func TestPlayerPlan(t *testing.T) { err: "group by expression is not allowed to reference an aggregate expression: a", }} + tableKeys := map[string][]string{ + "t1": {"c1"}, + } + for _, tcase := range testcases { - plan, err := buildPlayerPlan(tcase.input) + plan, err := buildPlayerPlan(tcase.input, tableKeys) gotPlan, _ := json.Marshal(plan) wantPlan, _ := json.Marshal(tcase.plan) if string(gotPlan) != string(wantPlan) { diff --git a/go/vt/vttablet/tabletmanager/vreplication/table_plan_builder.go b/go/vt/vttablet/tabletmanager/vreplication/table_plan_builder.go new file mode 100644 index 00000000000..4f6464cde9c --- /dev/null +++ b/go/vt/vttablet/tabletmanager/vreplication/table_plan_builder.go @@ -0,0 +1,498 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vreplication + +import ( + "fmt" + "sort" + "strings" + + "vitess.io/vitess/go/vt/sqlparser" + + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +type tablePlanBuilder struct { + name sqlparser.TableIdent + sendSelect *sqlparser.Select + selColumns map[string]bool + colExprs []*colExpr + onInsert insertType + pkCols []*colExpr +} + +// colExpr describes the processing to be performed to +// compute the value of the target table column. +type colExpr struct { + colName sqlparser.ColIdent + // operation==opExpr: full expression is set + // operation==opCount: nothing is set. + // operation==opSum: for 'sum(a)', expr is set to 'a'. + operation operation + expr sqlparser.Expr + // references contains all the column names referenced in the expression. + references map[string]bool + + isGrouped bool + isPK bool +} + +// operation is the opcode for the colExpr. +type operation int + +// The following values are the various colExpr opcodes. +const ( + opExpr = operation(iota) + opCount + opSum +) + +// insertType describes the type of insert statement to generate. +type insertType int + +// The following values are the various insert types. +const ( + insertNormal = insertType(iota) + insertOndup + insertIgnore +) + +func buildPlayerPlan(filter *binlogdatapb.Filter, tableKeys map[string][]string) (*PlayerPlan, error) { + plan := &PlayerPlan{ + VStreamFilter: &binlogdatapb.Filter{ + Rules: make([]*binlogdatapb.Rule, len(filter.Rules)), + }, + TablePlans: make(map[string]*TablePlan), + } + for i, rule := range filter.Rules { + if strings.HasPrefix(rule.Match, "/") { + plan.VStreamFilter.Rules[i] = rule + continue + } + sendRule, tablePlan, err := buildTablePlan(rule, tableKeys) + if err != nil { + return nil, err + } + plan.VStreamFilter.Rules[i] = sendRule + plan.TablePlans[sendRule.Match] = tablePlan + } + return plan, nil +} + +func buildTablePlan(rule *binlogdatapb.Rule, tableKeys map[string][]string) (*binlogdatapb.Rule, *TablePlan, error) { + sel, fromTable, err := analyzeSelectFrom(rule.Filter) + if err != nil { + return nil, nil, err + } + sendRule := &binlogdatapb.Rule{ + Match: fromTable, + } + + if expr, ok := sel.SelectExprs[0].(*sqlparser.StarExpr); ok { + if len(sel.SelectExprs) != 1 { + return nil, nil, fmt.Errorf("unexpected: %v", sqlparser.String(sel)) + } + if !expr.TableName.IsEmpty() { + return nil, nil, fmt.Errorf("unsupported qualifier for '*' expression: %v", sqlparser.String(expr)) + } + sendRule.Filter = rule.Filter + return sendRule, &TablePlan{Name: rule.Match}, nil + } + + tpb := &tablePlanBuilder{ + name: sqlparser.NewTableIdent(rule.Match), + sendSelect: &sqlparser.Select{ + From: sel.From, + Where: sel.Where, + }, + selColumns: make(map[string]bool), + } + + if err := tpb.analyzeExprs(sel.SelectExprs); err != nil { + return nil, nil, err + } + if err := tpb.analyzeGroupBy(sel.GroupBy); err != nil { + return nil, nil, err + } + if err := tpb.analyzePK(tableKeys); err != nil { + return nil, nil, err + } + + sendRule.Filter = sqlparser.String(tpb.sendSelect) + tablePlan := tpb.generate(tableKeys) + return sendRule, tablePlan, nil +} + +func buildTablePlanFromFields(tableName string, fields []*querypb.Field, tableKeys map[string][]string) (*TablePlan, error) { + tpb := &tablePlanBuilder{ + name: sqlparser.NewTableIdent(tableName), + } + for _, field := range fields { + colName := sqlparser.NewColIdent(field.Name) + cexpr := &colExpr{ + colName: colName, + expr: &sqlparser.ColName{ + Name: colName, + }, + references: map[string]bool{ + field.Name: true, + }, + } + tpb.colExprs = append(tpb.colExprs, cexpr) + } + if err := tpb.analyzePK(tableKeys); err != nil { + return nil, err + } + return tpb.generate(tableKeys), nil +} + +func (tpb *tablePlanBuilder) generate(tableKeys map[string][]string) *TablePlan { + refmap := make(map[string]bool) + for _, cexpr := range tpb.pkCols { + for k := range cexpr.references { + refmap[k] = true + } + } + pkrefs := make([]string, 0, len(refmap)) + for k := range refmap { + pkrefs = append(pkrefs, k) + } + sort.Strings(pkrefs) + return &TablePlan{ + Name: tpb.name.String(), + PKReferences: pkrefs, + Insert: tpb.generateInsertStatement(), + Update: tpb.generateUpdateStatement(), + Delete: tpb.generateDeleteStatement(), + } +} + +func analyzeSelectFrom(query string) (sel *sqlparser.Select, from string, err error) { + statement, err := sqlparser.Parse(query) + if err != nil { + return nil, "", err + } + sel, ok := statement.(*sqlparser.Select) + if !ok { + return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(statement)) + } + if sel.Distinct != "" { + return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel)) + } + if len(sel.From) > 1 { + return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel)) + } + node, ok := sel.From[0].(*sqlparser.AliasedTableExpr) + if !ok { + return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel)) + } + fromTable := sqlparser.GetTableName(node.Expr) + if fromTable.IsEmpty() { + return nil, "", fmt.Errorf("unexpected: %v", sqlparser.String(sel)) + } + return sel, fromTable.String(), nil +} + +func (tpb *tablePlanBuilder) analyzeExprs(selExprs sqlparser.SelectExprs) error { + for _, selExpr := range selExprs { + cexpr, err := tpb.analyzeExpr(selExpr) + if err != nil { + return err + } + tpb.colExprs = append(tpb.colExprs, cexpr) + } + return nil +} + +func (tpb *tablePlanBuilder) analyzeExpr(selExpr sqlparser.SelectExpr) (*colExpr, error) { + aliased, ok := selExpr.(*sqlparser.AliasedExpr) + if !ok { + return nil, fmt.Errorf("unexpected: %v", sqlparser.String(selExpr)) + } + as := aliased.As + if as.IsEmpty() { + as = sqlparser.NewColIdent(sqlparser.String(aliased.Expr)) + } + cexpr := &colExpr{ + colName: as, + references: make(map[string]bool), + } + if expr, ok := aliased.Expr.(*sqlparser.FuncExpr); ok { + if expr.Distinct || len(expr.Exprs) != 1 { + return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) + } + if aliased.As.IsEmpty() { + return nil, fmt.Errorf("expression needs an alias: %v", sqlparser.String(expr)) + } + switch fname := expr.Name.Lowered(); fname { + case "count": + if _, ok := expr.Exprs[0].(*sqlparser.StarExpr); !ok { + return nil, fmt.Errorf("only count(*) is supported: %v", sqlparser.String(expr)) + } + cexpr.operation = opCount + return cexpr, nil + case "sum": + aInner, ok := expr.Exprs[0].(*sqlparser.AliasedExpr) + if !ok { + return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) + } + innerCol, ok := aInner.Expr.(*sqlparser.ColName) + if !ok { + return nil, fmt.Errorf("unexpected: %v", sqlparser.String(expr)) + } + if !innerCol.Qualifier.IsEmpty() { + return nil, fmt.Errorf("unsupported qualifier for column: %v", sqlparser.String(innerCol)) + } + cexpr.operation = opSum + cexpr.expr = innerCol + tpb.addCol(innerCol.Name) + cexpr.references[innerCol.Name.Lowered()] = true + return cexpr, nil + } + } + err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + switch node := node.(type) { + case *sqlparser.ColName: + if !node.Qualifier.IsEmpty() { + return false, fmt.Errorf("unsupported qualifier for column: %v", sqlparser.String(node)) + } + tpb.addCol(node.Name) + cexpr.references[node.Name.Lowered()] = true + case *sqlparser.Subquery: + return false, fmt.Errorf("unsupported subquery: %v", sqlparser.String(node)) + case *sqlparser.FuncExpr: + if node.IsAggregate() { + return false, fmt.Errorf("unexpected: %v", sqlparser.String(node)) + } + } + return true, nil + }, aliased.Expr) + if err != nil { + return nil, err + } + cexpr.expr = aliased.Expr + return cexpr, nil +} + +func (tpb *tablePlanBuilder) addCol(ident sqlparser.ColIdent) { + if tpb.selColumns[ident.Lowered()] { + return + } + tpb.selColumns[ident.Lowered()] = true + tpb.sendSelect.SelectExprs = append(tpb.sendSelect.SelectExprs, &sqlparser.AliasedExpr{ + Expr: &sqlparser.ColName{Name: ident}, + }) +} + +func (tpb *tablePlanBuilder) analyzeGroupBy(groupBy sqlparser.GroupBy) error { + if groupBy == nil { + return nil + } + for _, expr := range groupBy { + colname, ok := expr.(*sqlparser.ColName) + if !ok { + return fmt.Errorf("unexpected: %v", sqlparser.String(expr)) + } + cexpr := tpb.findCol(colname.Name) + if cexpr == nil { + return fmt.Errorf("group by expression does not reference an alias in the select list: %v", sqlparser.String(expr)) + } + if cexpr.operation != opExpr { + return fmt.Errorf("group by expression is not allowed to reference an aggregate expression: %v", sqlparser.String(expr)) + } + cexpr.isGrouped = true + } + tpb.onInsert = insertIgnore + for _, cExpr := range tpb.colExprs { + if !cExpr.isGrouped { + tpb.onInsert = insertOndup + break + } + } + return nil +} + +func (tpb *tablePlanBuilder) analyzePK(tableKeys map[string][]string) error { + pkcols, ok := tableKeys[tpb.name.String()] + if !ok { + return fmt.Errorf("table %s not found in schema", tpb.name) + } + for _, pkcol := range pkcols { + cexpr := tpb.findCol(sqlparser.NewColIdent(pkcol)) + if cexpr == nil { + return fmt.Errorf("primary key column %s not found in select list", pkcol) + } + if cexpr.operation != opExpr { + return fmt.Errorf("primary key column %s is not allowed to reference an aggregate expression", pkcol) + } + cexpr.isPK = true + tpb.pkCols = append(tpb.pkCols, cexpr) + } + return nil +} + +func (tpb *tablePlanBuilder) findCol(name sqlparser.ColIdent) *colExpr { + for _, cexpr := range tpb.colExprs { + if cexpr.colName.Equal(name) { + return cexpr + } + } + return nil +} + +func (tpb *tablePlanBuilder) generateInsertStatement() *sqlparser.ParsedQuery { + bvf := &bindvarFormatter{} + buf := sqlparser.NewTrackedBuffer(bvf.formatter) + if tpb.onInsert == insertIgnore { + buf.Myprintf("insert ignore into %v set ", tpb.name) + } else { + buf.Myprintf("insert into %v set ", tpb.name) + } + tpb.generateInsertValues(buf, bvf) + if tpb.onInsert == insertOndup { + buf.Myprintf(" on duplicate key update ") + tpb.generateUpdate(buf, bvf, false /* before */, true /* after */) + } + return buf.ParsedQuery() +} + +func (tpb *tablePlanBuilder) generateUpdateStatement() *sqlparser.ParsedQuery { + if tpb.onInsert == insertIgnore { + return tpb.generateInsertStatement() + } + bvf := &bindvarFormatter{} + buf := sqlparser.NewTrackedBuffer(bvf.formatter) + buf.Myprintf("update %v set ", tpb.name) + tpb.generateUpdate(buf, bvf, true /* before */, true /* after */) + tpb.generateWhere(buf, bvf) + return buf.ParsedQuery() +} + +func (tpb *tablePlanBuilder) generateDeleteStatement() *sqlparser.ParsedQuery { + bvf := &bindvarFormatter{} + buf := sqlparser.NewTrackedBuffer(bvf.formatter) + switch tpb.onInsert { + case insertNormal: + buf.Myprintf("delete from %v", tpb.name) + tpb.generateWhere(buf, bvf) + case insertOndup: + buf.Myprintf("update %v set ", tpb.name) + tpb.generateUpdate(buf, bvf, true /* before */, false /* after */) + tpb.generateWhere(buf, bvf) + case insertIgnore: + return nil + } + return buf.ParsedQuery() +} + +func (tpb *tablePlanBuilder) generateInsertValues(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter) { + bvf.mode = bvAfter + separator := "" + for _, cexpr := range tpb.colExprs { + buf.Myprintf("%s%s=", separator, cexpr.colName.String()) + separator = ", " + switch cexpr.operation { + case opExpr: + buf.Myprintf("%v", cexpr.expr) + case opCount: + buf.WriteString("1") + case opSum: + buf.Myprintf("ifnull(%v, 0)", cexpr.expr) + } + } +} + +func (tpb *tablePlanBuilder) generateUpdate(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter, before, after bool) { + separator := "" + for _, cexpr := range tpb.colExprs { + if cexpr.isGrouped || cexpr.isPK { + continue + } + buf.Myprintf("%s%s=", separator, cexpr.colName.String()) + separator = ", " + switch cexpr.operation { + case opExpr: + if after { + bvf.mode = bvAfter + buf.Myprintf("%v", cexpr.expr) + } else { + buf.WriteString("null") + } + case opCount: + switch { + case before && after: + buf.Myprintf("%s", cexpr.colName.String()) + case before: + buf.Myprintf("%s-1", cexpr.colName.String()) + case after: + buf.Myprintf("%s+1", cexpr.colName.String()) + } + case opSum: + buf.Myprintf("%s", cexpr.colName.String()) + if before { + bvf.mode = bvBefore + buf.Myprintf("-ifnull(%v, 0)", cexpr.expr) + } + if after { + bvf.mode = bvAfter + buf.Myprintf("+ifnull(%v, 0)", cexpr.expr) + } + } + } +} + +func (tpb *tablePlanBuilder) generateWhere(buf *sqlparser.TrackedBuffer, bvf *bindvarFormatter) { + buf.WriteString(" where ") + bvf.mode = bvBefore + separator := "" + for _, cexpr := range tpb.pkCols { + if _, ok := cexpr.expr.(*sqlparser.ColName); ok { + buf.Myprintf("%s%s=%v", separator, cexpr.colName.String(), cexpr.expr) + } else { + // Parenthesize non-trivial expressions. + buf.Myprintf("%s%s=(%v)", separator, cexpr.colName.String(), cexpr.expr) + } + separator = " and " + } +} + +type bindvarFormatter struct { + mode bindvarMode +} + +type bindvarMode int + +const ( + bvNone = bindvarMode(iota) + bvBefore + bvAfter +) + +func (bvf *bindvarFormatter) formatter(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) { + if node, ok := node.(*sqlparser.ColName); ok { + switch bvf.mode { + case bvBefore: + buf.WriteArg(fmt.Sprintf(":b_%s", node.Name.String())) + return + case bvAfter: + buf.WriteArg(fmt.Sprintf(":a_%s", node.Name.String())) + return + } + } + node.Format(buf) +} diff --git a/go/vt/vttablet/tabletmanager/vreplication/vplayer.go b/go/vt/vttablet/tabletmanager/vreplication/vplayer.go index 5bf9bcafd81..b1344c6ae75 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vplayer.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vplayer.go @@ -25,12 +25,10 @@ import ( "golang.org/x/net/context" "vitess.io/vitess/go/mysql" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/binlog/binlogplayer" "vitess.io/vitess/go/vt/grpcclient" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/mysqlctl" - "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vttablet/tabletconn" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" @@ -70,6 +68,8 @@ type vplayer struct { timeOffsetNs int64 stopPos mysql.Position + tableKeys map[string][]string + // pplan is built based on the source Filter at the beginning. pplan *PlayerPlan // tplans[table] is built for each table based on pplan and schema info @@ -130,7 +130,12 @@ func (vp *vplayer) play(ctx context.Context) error { } log.Infof("Starting VReplication player id: %v, startPos: %v, stop: %v, source: %v, filter: %v", vp.id, startPos, vp.stopPos, vp.sourceTablet, vp.source) - plan, err := buildPlayerPlan(vp.source.Filter) + tableKeys, err := vp.buildTableKeys() + if err != nil { + return err + } + vp.tableKeys = tableKeys + plan, err := buildPlayerPlan(vp.source.Filter, tableKeys) if err != nil { return err } @@ -198,6 +203,22 @@ func (vp *vplayer) play(ctx context.Context) error { } } +func (vp *vplayer) buildTableKeys() (map[string][]string, error) { + schema, err := vp.mysqld.GetSchema(vp.dbClient.DBName(), []string{"/.*/"}, nil, false) + if err != nil { + return nil, err + } + tableKeys := make(map[string][]string) + for _, td := range schema.TableDefinitions { + if len(td.PrimaryKeyColumns) != 0 { + tableKeys[td.Name] = td.PrimaryKeyColumns + } else { + tableKeys[td.Name] = td.Columns + } + } + return tableKeys, nil +} + func (vp *vplayer) applyEvents(ctx context.Context, relay *relayLog) error { for { items, err := relay.Fetch() @@ -384,58 +405,21 @@ func (vp *vplayer) setState(state, message string) error { func (vp *vplayer) updatePlan(fieldEvent *binlogdatapb.FieldEvent) error { prelim := vp.pplan.TablePlans[fieldEvent.TableName] - tplan := &TablePlan{ - Name: fieldEvent.TableName, - } - if prelim != nil { - *tplan = *prelim - } - tplan.Fields = fieldEvent.Fields - - if tplan.ColExprs == nil { - tplan.ColExprs = make([]*ColExpr, len(tplan.Fields)) - for i, field := range tplan.Fields { - tplan.ColExprs[i] = &ColExpr{ - ColName: sqlparser.NewColIdent(field.Name), - ColNum: i, - } - } - } else { - for _, cExpr := range tplan.ColExprs { - if cExpr.ColNum >= len(tplan.Fields) { - // Unreachable code. - return fmt.Errorf("columns received from vreplication: %v, do not match expected: %v", tplan.Fields, tplan.ColExprs) - } + if prelim == nil { + prelim = &TablePlan{ + Name: fieldEvent.TableName, } } - - pkcols, err := vp.mysqld.GetPrimaryKeyColumns(vp.dbClient.DBName(), tplan.Name) - if err != nil { - return fmt.Errorf("error fetching pk columns for %s: %v", tplan.Name, err) - } - if len(pkcols) == 0 { - // If the table doesn't have a PK, then we treat all columns as PK. - pkcols, err = vp.mysqld.GetColumns(vp.dbClient.DBName(), tplan.Name) - if err != nil { - return fmt.Errorf("error fetching pk columns for %s: %v", tplan.Name, err) - } + if prelim.Insert != nil { + prelim.Fields = fieldEvent.Fields + vp.tplans[fieldEvent.TableName] = prelim + return nil } - for _, pkcol := range pkcols { - found := false - for i, cExpr := range tplan.ColExprs { - if cExpr.ColName.EqualString(pkcol) { - found = true - tplan.PKCols = append(tplan.PKCols, &ColExpr{ - ColName: cExpr.ColName, - ColNum: i, - }) - break - } - } - if !found { - return fmt.Errorf("primary key column %s missing from select list for table %s", pkcol, tplan.Name) - } + tplan, err := buildTablePlanFromFields(prelim.Name, fieldEvent.Fields, vp.tableKeys) + if err != nil { + return err } + tplan.Fields = fieldEvent.Fields vp.tplans[fieldEvent.TableName] = tplan return nil } @@ -446,7 +430,11 @@ func (vp *vplayer) applyRowEvent(ctx context.Context, rowEvent *binlogdatapb.Row return fmt.Errorf("unexpected event on table %s", rowEvent.TableName) } for _, change := range rowEvent.RowChanges { - for _, query := range tplan.GenerateStatements(change) { + queries, err := tplan.generateStatements(change) + if err != nil { + return err + } + for _, query := range queries { if err := vp.exec(ctx, query); err != nil { return err } @@ -490,13 +478,3 @@ func (vp *vplayer) exec(ctx context.Context, sql string) error { } return nil } - -func encodeValue(sql *sqlparser.TrackedBuffer, value sqltypes.Value) { - // This is currently a separate function because special handling - // may be needed for certain types. - // Previously, this function used to convert timestamp to the session - // time zone, but we now set the session timezone to UTC. So, the timestamp - // value we receive as UTC can be sent as is. - // TODO(sougou): handle BIT data type here? - value.EncodeSQL(sql) -} diff --git a/go/vt/vttablet/tabletmanager/vreplication/vplayer_test.go b/go/vt/vttablet/tabletmanager/vreplication/vplayer_test.go index 3af90ad45f4..228dba7207f 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vplayer_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vplayer_test.go @@ -129,7 +129,7 @@ func TestPlayerFilters(t *testing.T) { input: "insert into src2 values(1, 2, 3)", output: []string{ "begin", - "insert into dst2 set id=1, val1=2, sval2=3, rcount=1 on duplicate key update val1=2, sval2=sval2+3, rcount=rcount+1", + "insert into dst2 set id=1, val1=2, sval2=ifnull(3, 0), rcount=1 on duplicate key update val1=2, sval2=sval2+ifnull(3, 0), rcount=rcount+1", "/update _vt.vreplication set pos=", "commit", }, @@ -142,7 +142,7 @@ func TestPlayerFilters(t *testing.T) { input: "update src2 set val1=5, val2=1 where id=1", output: []string{ "begin", - "update dst2 set val1=5, sval2=sval2-3+1 where id=1", + "update dst2 set val1=5, sval2=sval2-ifnull(3, 0)+ifnull(1, 0), rcount=rcount where id=1", "/update _vt.vreplication set pos=", "commit", }, @@ -155,7 +155,7 @@ func TestPlayerFilters(t *testing.T) { input: "delete from src2 where id=1", output: []string{ "begin", - "update dst2 set val1=NULL, sval2=sval2-1, rcount=rcount-1 where id=1", + "update dst2 set val1=null, sval2=sval2-ifnull(1, 0), rcount=rcount-1 where id=1", "/update _vt.vreplication set pos=", "commit", }, @@ -310,7 +310,7 @@ func TestPlayerUpdates(t *testing.T) { }{{ // Start with all nulls input: "insert into t1 values(1, null, null, null)", - output: "insert into t1 set id=1, grouped=null, ungrouped=null, summed=0, rcount=1 on duplicate key update ungrouped=null, summed=summed, rcount=rcount+1", + output: "insert into t1 set id=1, grouped=null, ungrouped=null, summed=ifnull(null, 0), rcount=1 on duplicate key update ungrouped=null, summed=summed+ifnull(null, 0), rcount=rcount+1", table: "t1", data: [][]string{ {"1", "", "", "0", "1"}, @@ -318,7 +318,7 @@ func TestPlayerUpdates(t *testing.T) { }, { // null to null values input: "update t1 set grouped=1 where id=1", - output: "", + output: "update t1 set ungrouped=null, summed=summed-ifnull(null, 0)+ifnull(null, 0), rcount=rcount where id=1", table: "t1", data: [][]string{ {"1", "", "", "0", "1"}, @@ -326,7 +326,7 @@ func TestPlayerUpdates(t *testing.T) { }, { // null to non-null values input: "update t1 set ungrouped=1, summed=1 where id=1", - output: "update t1 set ungrouped=1, summed=summed+1 where id=1", + output: "update t1 set ungrouped=1, summed=summed-ifnull(null, 0)+ifnull(1, 0), rcount=rcount where id=1", table: "t1", data: [][]string{ {"1", "", "1", "1", "1"}, @@ -334,7 +334,7 @@ func TestPlayerUpdates(t *testing.T) { }, { // non-null to non-null values input: "update t1 set ungrouped=2, summed=2 where id=1", - output: "update t1 set ungrouped=2, summed=summed-1+2 where id=1", + output: "update t1 set ungrouped=2, summed=summed-ifnull(1, 0)+ifnull(2, 0), rcount=rcount where id=1", table: "t1", data: [][]string{ {"1", "", "2", "2", "1"}, @@ -342,7 +342,7 @@ func TestPlayerUpdates(t *testing.T) { }, { // non-null to null values input: "update t1 set ungrouped=null, summed=null where id=1", - output: "update t1 set ungrouped=null, summed=summed-2 where id=1", + output: "update t1 set ungrouped=null, summed=summed-ifnull(2, 0)+ifnull(null, 0), rcount=rcount where id=1", table: "t1", data: [][]string{ {"1", "", "", "0", "1"}, @@ -350,7 +350,7 @@ func TestPlayerUpdates(t *testing.T) { }, { // insert non-null values input: "insert into t1 values(2, 2, 3, 4)", - output: "insert into t1 set id=2, grouped=2, ungrouped=3, summed=4, rcount=1 on duplicate key update ungrouped=3, summed=summed+4, rcount=rcount+1", + output: "insert into t1 set id=2, grouped=2, ungrouped=3, summed=ifnull(4, 0), rcount=1 on duplicate key update ungrouped=3, summed=summed+ifnull(4, 0), rcount=rcount+1", table: "t1", data: [][]string{ {"1", "", "", "0", "1"}, @@ -359,7 +359,7 @@ func TestPlayerUpdates(t *testing.T) { }, { // delete non-null values input: "delete from t1 where id=2", - output: "update t1 set ungrouped=NULL, summed=summed-4, rcount=rcount-1 where id=2", + output: "update t1 set ungrouped=null, summed=summed-ifnull(4, 0), rcount=rcount-1 where id=2", table: "t1", data: [][]string{ {"1", "", "", "0", "1"}, @@ -416,9 +416,9 @@ func TestPlayerRowMove(t *testing.T) { }) expectDBClientQueries(t, []string{ "begin", - "insert into dst set val1=1, sval2=1, rcount=1 on duplicate key update sval2=sval2+1, rcount=rcount+1", - "insert into dst set val1=2, sval2=2, rcount=1 on duplicate key update sval2=sval2+2, rcount=rcount+1", - "insert into dst set val1=2, sval2=3, rcount=1 on duplicate key update sval2=sval2+3, rcount=rcount+1", + "insert into dst set val1=1, sval2=ifnull(1, 0), rcount=1 on duplicate key update sval2=sval2+ifnull(1, 0), rcount=rcount+1", + "insert into dst set val1=2, sval2=ifnull(2, 0), rcount=1 on duplicate key update sval2=sval2+ifnull(2, 0), rcount=rcount+1", + "insert into dst set val1=2, sval2=ifnull(3, 0), rcount=1 on duplicate key update sval2=sval2+ifnull(3, 0), rcount=rcount+1", "/update _vt.vreplication set pos=", "commit", }) @@ -432,8 +432,8 @@ func TestPlayerRowMove(t *testing.T) { }) expectDBClientQueries(t, []string{ "begin", - "update dst set sval2=sval2-3, rcount=rcount-1 where val1=2", - "insert into dst set val1=1, sval2=4, rcount=1 on duplicate key update sval2=sval2+4, rcount=rcount+1", + "update dst set sval2=sval2-ifnull(3, 0), rcount=rcount-1 where val1=2", + "insert into dst set val1=1, sval2=ifnull(4, 0), rcount=1 on duplicate key update sval2=sval2+ifnull(4, 0), rcount=rcount+1", "/update _vt.vreplication set pos=", "commit", })