From 66c5fb290d6c35b72c1cd7edd1dfb3adc7609380 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 26 Feb 2023 14:16:56 +0800 Subject: [PATCH 01/19] hash join --- pkg/runtime/optimize/dml/select.go | 102 ++++++++++++------ pkg/runtime/plan/dml/hash_join.go | 161 +++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+), 29 deletions(-) create mode 100644 pkg/runtime/plan/dml/hash_join.go diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index 945b2827..8f9dda07 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -358,7 +358,7 @@ func handleGroupBy(parentPlan proto.Plan, stmt *ast.SelectStatement) (proto.Plan func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectStatement) (proto.Plan, error) { join := stmt.From[0].Source().(*ast.JoinNode) - compute := func(tableSource *ast.TableSourceNode) (database, alias string, shardList []string, err error) { + compute := func(tableSource *ast.TableSourceNode) (database string, shardsMap map[string][]string, alias string, err error) { table := tableSource.TableName() if table == nil { err = errors.New("must table, not statement or join node") @@ -371,57 +371,101 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt if err != nil { return } + + shardsMap = make(map[string][]string, len(shards)) + // table no shard if shards == nil { - shardList = append(shardList, table.Suffix()) - return - } - // table shard more than one db - if len(shards) > 1 { - err = errors.New("not support more than one db") + shardsMap[database] = append(shardsMap[database], table.Suffix()) return } - for k, v := range shards { - database = k - shardList = v - } - + // table has shard + shardsMap = shards if alias == "" { alias = table.Suffix() } - return } - dbLeft, aliasLeft, shardLeft, err := compute(join.Left) + dbLeft, shardsLeft, aliasLeft, err := compute(join.Left) if err != nil { return nil, err } - dbRight, aliasRight, shardRight, err := compute(join.Right) + dbRight, shardsRight, aliasRight, err := compute(join.Right) if err != nil { return nil, err } - if dbLeft != "" && dbRight != "" && dbLeft != dbRight { - return nil, errors.New("not support more than one db") + // one db + if dbLeft == dbRight && len(shardsLeft) == 1 && len(shardsRight) == 1 { + joinPan := &dml.SimpleJoinPlan{ + Left: &dml.JoinTable{ + Tables: shardsLeft[dbLeft], + Alias: aliasLeft, + }, + Join: join, + Right: &dml.JoinTable{ + Tables: shardsRight[dbRight], + Alias: aliasRight, + }, + Stmt: o.Stmt.(*ast.SelectStatement), + } + joinPan.BindArgs(o.Args) + return joinPan, nil } - joinPan := &dml.SimpleJoinPlan{ - Left: &dml.JoinTable{ - Tables: shardLeft, - Alias: aliasLeft, - }, + // multiple shards & hash join + hashJoinPlan := &dml.HashJoinPlan{ Join: join, - Right: &dml.JoinTable{ - Tables: shardRight, - Alias: aliasRight, - }, - Stmt: o.Stmt.(*ast.SelectStatement), + Stmt: stmt, + } + + buildShards := shardsLeft + probeShards := shardsRight + if len(shardsLeft) > len(shardsRight) { + buildShards = shardsRight + probeShards = shardsLeft + } + + rewriteToSingle := func(shards map[string][]string) []proto.Plan { + // todo 过滤下select where条件对应的表 + selectStmt := &ast.SelectStatement{ + Select: stmt.Select, + From: ast.FromNode{join.Right}, + Where: stmt.Where, + } + + plans := make([]proto.Plan, 0, len(shards)) + for k, v := range shards { + next := &dml.SimpleQueryPlan{ + Database: k, + Tables: v, + Stmt: selectStmt, + } + next.BindArgs(o.Args) + 
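The per-shard fan-out and the build/probe-side choice above can be pictured with a short standalone sketch. The types and query strings below are illustrative stand-ins, not arana's optimizer API; they only assume the same db -> tables shard-map shape computed above:

    package main

    import "fmt"

    // shardMap maps a physical database to its physical table names,
    // mirroring the map[string][]string returned by the shard computation.
    type shardMap map[string][]string

    // expand turns a shard map into one single-table query per physical table.
    func expand(logical string, shards shardMap) []string {
        var queries []string
        for db, tables := range shards {
            for _, tb := range tables {
                queries = append(queries,
                    fmt.Sprintf("SELECT * FROM %s.%s /* logical table: %s */", db, tb, logical))
            }
        }
        return queries
    }

    func main() {
        left := shardMap{"school_0000": {"student_0000", "student_0001"}}
        right := shardMap{"school_0000": {"salaries_0000"}}

        buildQueries, probeQueries := expand("student", left), expand("salaries", right)
        // The side that expands to fewer per-shard queries builds the hash table;
        // the larger side is streamed through the probe phase.
        if len(buildQueries) > len(probeQueries) {
            buildQueries, probeQueries = probeQueries, buildQueries
        }
        fmt.Println("build:", buildQueries)
        fmt.Println("probe:", probeQueries)
    }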
plans = append(plans, next) + } + return plans + } + + hashJoinPlan.BuildPlans = rewriteToSingle(buildShards) + hashJoinPlan.ProbePlans = rewriteToSingle(probeShards) + + // todo 需要兼容多个on条件 ast.LogicalExpressionNode + onExpression := join.On.(*ast.PredicateExpressionNode).P.(*ast.BinaryComparisonPredicateNode) + onLeft := onExpression.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + onRight := onExpression.Right.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + + if onLeft[0] == aliasLeft { + hashJoinPlan.BuildKey[0] = onLeft[1] + } + + if onRight[0] == aliasRight { + hashJoinPlan.ProbeKey[0] = onRight[1] } - joinPan.BindArgs(o.Args) - return joinPan, nil + return hashJoinPlan, nil } func getSelectFlag(ru *rule.Rule, stmt *ast.SelectStatement) (flag uint32) { diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go new file mode 100644 index 00000000..257fdce0 --- /dev/null +++ b/pkg/runtime/plan/dml/hash_join.go @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" + "github.com/arana-db/arana/third_party/base58" + "github.com/cespare/xxhash/v2" + "github.com/pkg/errors" + "io" +) + +type HashJoinPlan struct { + BuildPlans []proto.Plan + ProbePlans []proto.Plan + + BuildKey []string + ProbeKey []string + hashArea map[string]proto.Row + + Join *ast.JoinNode + Stmt *ast.SelectStatement +} + +func (h *HashJoinPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (h *HashJoinPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + ctx, span := plan.Tracer.Start(ctx, "HashJoinPlan.ExecIn") + defer span.End() + + // build stage + buildDs, err := h.build(ctx, conn) + if err != nil { + return nil, errors.WithStack(err) + } + + // probe stage + probeDs, err := h.probe(ctx, conn, buildDs) + if err != nil { + return nil, errors.WithStack(err) + } + + resultx.New(resultx.WithDataset(probeDs)) + return nil, nil +} + +func (h *HashJoinPlan) queryAggregate(ctx context.Context, conn proto.VConn, plans []proto.Plan) (proto.Dataset, error) { + var generators []dataset.GenerateFunc + for _, it := range plans { + it := it + generators = append(generators, func() (proto.Dataset, error) { + res, err := it.ExecIn(ctx, conn) + if err != nil { + return nil, errors.WithStack(err) + } + return res.Dataset() + }) + } + + ds, err := dataset.Fuse(generators[0], generators[1:]...) 
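The generator pattern above fans several per-shard result sets into one stream before the join consumes it. A minimal plain-Go sketch of that fan-in, with a map-based row type standing in for proto.Row and a simple concatenation standing in for dataset.Fuse:

    package main

    import "fmt"

    type row map[string]any

    // generate lazily produces the rows of one shard.
    type generate func() ([]row, error)

    // fuse runs the generators in order and concatenates their rows.
    func fuse(gens ...generate) ([]row, error) {
        var out []row
        for _, g := range gens {
            rs, err := g()
            if err != nil {
                return nil, err
            }
            out = append(out, rs...)
        }
        return out, nil
    }

    func main() {
        shard0 := generate(func() ([]row, error) { return []row{{"uid": int64(1)}}, nil })
        shard1 := generate(func() ([]row, error) { return []row{{"uid": int64(2)}}, nil })
        all, _ := fuse(shard0, shard1)
        fmt.Println(all) // [map[uid:1] map[uid:2]]
    }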
+ if err != nil { + return nil, err + } + + // todo 将所有结果聚合 + return ds, nil +} + +func (h *HashJoinPlan) build(ctx context.Context, conn proto.VConn) (proto.Dataset, error) { + ds, err := h.queryAggregate(ctx, conn, h.BuildPlans) + if err != nil { + return nil, errors.WithStack(err) + } + + cn := h.BuildKey[0] + xh := xxhash.New() + // build map + for { + xh.Reset() + next, err := ds.Next() + if err == io.EOF { + break + } + + keyedRow := next.(proto.KeyedRow) + value, err := keyedRow.Get(cn) + if err != nil { + return nil, errors.WithStack(err) + } + + _, _ = xh.WriteString(value.String()) + h.hashArea[base58.Encode(xh.Sum(nil))] = next + } + + return ds, nil +} + +func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs proto.Dataset) (proto.Dataset, error) { + ds, err := h.queryAggregate(ctx, conn, h.ProbePlans) + if err != nil { + return nil, errors.WithStack(err) + } + + probeMapFunc := func(row proto.Row, columnName string) proto.Row { + keyedRow := row.(proto.KeyedRow) + value, _ := keyedRow.Get(columnName) + xh := xxhash.New() + _, _ = xh.WriteString(value.String()) + return h.hashArea[base58.Encode(xh.Sum(nil))] + } + + cn := h.ProbeKey[0] + filterFunc := func(row proto.Row) bool { + findRow := probeMapFunc(row, cn) + return findRow == nil + } + + bFields, _ := buildDs.Fields() + // aggregate fields + aggregateFieldsFunc := func(fields []proto.Field) []proto.Field { + return append(bFields, fields...) + } + + // aggregate row + fields, _ := ds.Fields() + transformFunc := func(row proto.Row) (proto.Row, error) { + dest := make([]proto.Value, len(fields)) + _ = row.Scan(dest) + + matchRow := probeMapFunc(row, cn) + bDest := make([]proto.Value, len(bFields)) + _ = matchRow.Scan(bDest) + + return rows.NewBinaryVirtualRow(append(bFields, fields...), append(bDest, dest...)), nil + } + + return dataset.Pipe(ds, dataset.Filter(filterFunc), dataset.Map(aggregateFieldsFunc, transformFunc)), nil +} From 1b4384fe27adc34fcf26b1925545647b9299e25f Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 26 Feb 2023 14:19:09 +0800 Subject: [PATCH 02/19] add remark --- pkg/runtime/plan/dml/hash_join.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index 257fdce0..5da40b7d 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -157,5 +157,6 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot return rows.NewBinaryVirtualRow(append(bFields, fields...), append(bDest, dest...)), nil } + // filter match row & aggregate fields and row return dataset.Pipe(ds, dataset.Filter(filterFunc), dataset.Map(aggregateFieldsFunc, transformFunc)), nil } From f24d3b41c53606ecef8da9d9354479c8cd51a508 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 26 Feb 2023 14:41:24 +0800 Subject: [PATCH 03/19] return resultx --- pkg/runtime/plan/dml/hash_join.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index 5da40b7d..6044f5c4 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -63,8 +63,7 @@ func (h *HashJoinPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Resu return nil, errors.WithStack(err) } - resultx.New(resultx.WithDataset(probeDs)) - return nil, nil + return resultx.New(resultx.WithDataset(probeDs)), nil } func (h *HashJoinPlan) queryAggregate(ctx context.Context, conn proto.VConn, plans 
[]proto.Plan) (proto.Dataset, error) { From 79126932ef07499fb2a8d6cfbbda5be4ab500090 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 26 Feb 2023 15:04:39 +0800 Subject: [PATCH 04/19] filter func --- pkg/runtime/plan/dml/hash_join.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index 6044f5c4..c9df6629 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -134,7 +134,7 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot cn := h.ProbeKey[0] filterFunc := func(row proto.Row) bool { findRow := probeMapFunc(row, cn) - return findRow == nil + return findRow != nil } bFields, _ := buildDs.Fields() From 7256cbee9198782d0b4a0a733a9db48ec1f439b4 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sat, 18 Mar 2023 14:05:18 +0800 Subject: [PATCH 05/19] optimize plan --- pkg/runtime/optimize/dml/select.go | 135 +++++++++++++++++------------ pkg/runtime/plan/dml/hash_join.go | 50 +++++------ 2 files changed, 103 insertions(+), 82 deletions(-) diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index fcc32f27..13d1af36 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -19,6 +19,7 @@ package dml import ( "context" + "github.com/arana-db/parser" "strings" ) @@ -415,7 +416,7 @@ func handleGroupBy(parentPlan proto.Plan, stmt *ast.SelectStatement) (proto.Plan // optimizeJoin ony support a join b in one db. // DEPRECATED: reimplement in the future func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectStatement) (proto.Plan, error) { - compute := func(tableSource *ast.TableSourceItem) (database, alias string, shardList []string, err error) { + compute := func(tableSource *ast.TableSourceItem) (database, alias string, shardsMap map[string][]string, err error) { table := tableSource.Source.(ast.TableName) if table == nil { err = errors.New("must table, not statement or join node") @@ -445,90 +446,114 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt return } - //dbLeft, shardsLeft, aliasLeft, err := compute(join.Left) from := stmt.From[0] - - dbLeft, aliasLeft, shardLeft, err := compute(&from.TableSourceItem) + dbLeft, aliasLeft, shardsLeft, err := compute(&from.TableSourceItem) if err != nil { return nil, err } - //dbRight, shardsRight, aliasRight, err := compute(join.Right) - dbRight, aliasRight, shardRight, err := compute(from.Joins[0].Target) + + dbRight, aliasRight, shardsRight, err := compute(from.Joins[0].Target) if err != nil { return nil, err } - if dbLeft != "" && dbRight != "" && dbLeft != dbRight { - return nil, errors.New("not support more than one db") - } + //if dbLeft != "" && dbRight != "" && dbLeft != dbRight { + // return nil, errors.New("not support more than one db") + //} - joinPan := &dml.SimpleJoinPlan{ - Left: &dml.JoinTable{ - Tables: shardLeft, - Alias: aliasLeft, - }, - Join: from.Joins[0], - Right: &dml.JoinTable{ - Tables: shardRight, - Alias: aliasRight, - }, - Stmt: o.Stmt.(*ast.SelectStatement), + // one db + if dbLeft == dbRight && len(shardsLeft) == 1 && len(shardsRight) == 1 { + joinPan := &dml.SimpleJoinPlan{ + Left: &dml.JoinTable{ + Tables: shardsLeft[dbLeft], + Alias: aliasLeft, + }, + Join: from.Joins[0], + Right: &dml.JoinTable{ + Tables: shardsRight[dbRight], + Alias: aliasRight, + }, + Stmt: o.Stmt.(*ast.SelectStatement), + } + 
joinPan.BindArgs(o.Args) + return joinPan, nil } - joinPan.BindArgs(o.Args) - return joinPan, nil -} + //multiple shards & do hash join + hashJoinPlan := &dml.HashJoinPlan{} - // multiple shards & hash join - hashJoinPlan := &dml.HashJoinPlan{ - Join: join, - Stmt: stmt, - } - - buildShards := shardsLeft - probeShards := shardsRight - if len(shardsLeft) > len(shardsRight) { - buildShards = shardsRight - probeShards = shardsLeft - } + //todo small table join large table - rewriteToSingle := func(shards map[string][]string) []proto.Plan { - // todo 过滤下select where条件对应的表 + rewriteToSingle := func(tableSource ast.TableSourceItem, shards map[string][]string) (proto.Plan, error) { selectStmt := &ast.SelectStatement{ Select: stmt.Select, - From: ast.FromNode{join.Right}, - Where: stmt.Where, + From: ast.FromNode{ + &ast.TableSourceNode{ + TableSourceItem: tableSource, + }, + }, } - plans := make([]proto.Plan, 0, len(shards)) - for k, v := range shards { - next := &dml.SimpleQueryPlan{ - Database: k, - Tables: v, - Stmt: selectStmt, - } - next.BindArgs(o.Args) - plans = append(plans, next) + var ( + err error + optimizer proto.Optimizer + plan proto.Plan + sb strings.Builder + ) + err = selectStmt.Restore(ast.RestoreDefault, &sb, nil) + if err != nil { + return nil, err + } + + p := parser.New() + stmtNode, err := p.ParseOneStmt(sb.String(), "", "") + if err != nil { + return nil, err } - return plans + + optimizer, err = optimize.NewOptimizer(o.Rule, o.Hints, stmtNode, o.Args) + if err != nil { + return nil, err + } + + plan, err = optimizeSelect(ctx, optimizer.(*optimize.Optimizer)) + if err != nil { + return nil, err + } + return plan, nil + } + + leftPlan, err := rewriteToSingle(from.TableSourceItem, shardsLeft) + if err != nil { + return nil, err } + hashJoinPlan.BuildPlan = leftPlan - hashJoinPlan.BuildPlans = rewriteToSingle(buildShards) - hashJoinPlan.ProbePlans = rewriteToSingle(probeShards) + rightPlan, err := rewriteToSingle(*from.Joins[0].Target, shardsRight) + if err != nil { + return nil, err + } + hashJoinPlan.ProbePlan = rightPlan + + onExpression, ok := from.Joins[0].On.(*ast.PredicateExpressionNode).P.(*ast.BinaryComparisonPredicateNode) + // todo support more than one 'ON' condition ast.LogicalExpressionNode + if !ok { + return nil, errors.New("not support more than one 'ON' condition") + } - // todo 需要兼容多个on条件 ast.LogicalExpressionNode - onExpression := join.On.(*ast.PredicateExpressionNode).P.(*ast.BinaryComparisonPredicateNode) onLeft := onExpression.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) onRight := onExpression.Right.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) if onLeft[0] == aliasLeft { - hashJoinPlan.BuildKey[0] = onLeft[1] + hashJoinPlan.BuildKey = append(hashJoinPlan.BuildKey, onLeft[1]) } if onRight[0] == aliasRight { - hashJoinPlan.ProbeKey[0] = onRight[1] + hashJoinPlan.ProbeKey = append(hashJoinPlan.ProbeKey, onRight[1]) } + //todo order by, limit, group by, having etc.. 
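The ON-condition handling above reduces a single equality predicate to one key per side by matching each column's qualifier against the table aliases. A hedged sketch of that split using plain strings instead of AST nodes (splitOn and its parameters are hypothetical names):

    package main

    import (
        "fmt"
        "strings"
    )

    // splitOn resolves "a.uid = b.emp_no" style conditions: the column qualified
    // by the build side's alias becomes the build key, the other the probe key.
    func splitOn(cond, buildAlias, probeAlias string) (buildKey, probeKey string, ok bool) {
        parts := strings.SplitN(cond, "=", 2)
        if len(parts) != 2 {
            return "", "", false
        }
        for _, p := range parts {
            qualifier, column, found := strings.Cut(strings.TrimSpace(p), ".")
            if !found {
                return "", "", false
            }
            switch qualifier {
            case buildAlias:
                buildKey = column
            case probeAlias:
                probeKey = column
            }
        }
        return buildKey, probeKey, buildKey != "" && probeKey != ""
    }

    func main() {
        b, p, ok := splitOn("student.uid = salaries.emp_no", "student", "salaries")
        fmt.Println(b, p, ok) // uid emp_no true
    }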
+ return hashJoinPlan, nil } diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index c9df6629..6c94183f 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -23,7 +23,6 @@ import ( "github.com/arana-db/arana/pkg/mysql/rows" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/resultx" - "github.com/arana-db/arana/pkg/runtime/ast" "github.com/arana-db/arana/pkg/runtime/plan" "github.com/arana-db/arana/third_party/base58" "github.com/cespare/xxhash/v2" @@ -32,15 +31,12 @@ import ( ) type HashJoinPlan struct { - BuildPlans []proto.Plan - ProbePlans []proto.Plan + BuildPlan proto.Plan + ProbePlan proto.Plan BuildKey []string ProbeKey []string hashArea map[string]proto.Row - - Join *ast.JoinNode - Stmt *ast.SelectStatement } func (h *HashJoinPlan) Type() proto.PlanType { @@ -66,36 +62,29 @@ func (h *HashJoinPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Resu return resultx.New(resultx.WithDataset(probeDs)), nil } -func (h *HashJoinPlan) queryAggregate(ctx context.Context, conn proto.VConn, plans []proto.Plan) (proto.Dataset, error) { - var generators []dataset.GenerateFunc - for _, it := range plans { - it := it - generators = append(generators, func() (proto.Dataset, error) { - res, err := it.ExecIn(ctx, conn) - if err != nil { - return nil, errors.WithStack(err) - } - return res.Dataset() - }) - } - - ds, err := dataset.Fuse(generators[0], generators[1:]...) +func (h *HashJoinPlan) queryAggregate(ctx context.Context, conn proto.VConn, plan proto.Plan) (proto.Result, error) { + result, err := plan.ExecIn(ctx, conn) if err != nil { return nil, err } // todo 将所有结果聚合 - return ds, nil + return result, nil } func (h *HashJoinPlan) build(ctx context.Context, conn proto.VConn) (proto.Dataset, error) { - ds, err := h.queryAggregate(ctx, conn, h.BuildPlans) + res, err := h.queryAggregate(ctx, conn, h.BuildPlan) if err != nil { return nil, errors.WithStack(err) } + ds, err := res.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } cn := h.BuildKey[0] xh := xxhash.New() + h.hashArea = make(map[string]proto.Row) // build map for { xh.Reset() @@ -118,7 +107,12 @@ func (h *HashJoinPlan) build(ctx context.Context, conn proto.VConn) (proto.Datas } func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs proto.Dataset) (proto.Dataset, error) { - ds, err := h.queryAggregate(ctx, conn, h.ProbePlans) + res, err := h.queryAggregate(ctx, conn, h.ProbePlan) + if err != nil { + return nil, errors.WithStack(err) + } + + ds, err := res.Dataset() if err != nil { return nil, errors.WithStack(err) } @@ -137,10 +131,10 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot return findRow != nil } - bFields, _ := buildDs.Fields() + buildFields, _ := buildDs.Fields() // aggregate fields aggregateFieldsFunc := func(fields []proto.Field) []proto.Field { - return append(bFields, fields...) + return append(buildFields, fields...) 
} // aggregate row @@ -150,12 +144,14 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot _ = row.Scan(dest) matchRow := probeMapFunc(row, cn) - bDest := make([]proto.Value, len(bFields)) + bDest := make([]proto.Value, len(buildFields)) _ = matchRow.Scan(bDest) - return rows.NewBinaryVirtualRow(append(bFields, fields...), append(bDest, dest...)), nil + return rows.NewBinaryVirtualRow(append(buildFields, fields...), append(bDest, dest...)), nil } + // todo left/right join + // filter match row & aggregate fields and row return dataset.Pipe(ds, dataset.Filter(filterFunc), dataset.Map(aggregateFieldsFunc, transformFunc)), nil } From cbd4bbc47d7c4987e1c49091d5f97a83ba7fac70 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 19 Mar 2023 22:02:11 +0800 Subject: [PATCH 06/19] optimize plan --- pkg/runtime/plan/dml/hash_join.go | 37 ++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index 6c94183f..8952bdf1 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -18,8 +18,10 @@ package dml import ( + "bytes" "context" "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/mysql/rows" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/resultx" @@ -67,8 +69,6 @@ func (h *HashJoinPlan) queryAggregate(ctx context.Context, conn proto.VConn, pla if err != nil { return nil, err } - - // todo 将所有结果聚合 return result, nil } @@ -137,6 +137,8 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot return append(buildFields, fields...) } + // todo left/right join + // aggregate row fields, _ := ds.Fields() transformFunc := func(row proto.Row) (proto.Row, error) { @@ -144,14 +146,33 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot _ = row.Scan(dest) matchRow := probeMapFunc(row, cn) - bDest := make([]proto.Value, len(buildFields)) - _ = matchRow.Scan(bDest) - - return rows.NewBinaryVirtualRow(append(buildFields, fields...), append(bDest, dest...)), nil + buildDest := make([]proto.Value, len(buildFields)) + _ = matchRow.Scan(buildDest) + + resFields := append(buildFields, fields...) + resDest := append(buildDest, dest...) 
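A simplified sketch of the row concatenation happening here, using generic slices in place of proto.Field and proto.Value; copying into fresh slices is a deliberate choice in the sketch so the inputs' backing arrays are never shared:

    package main

    import "fmt"

    // concatRow prepends the matched build-side columns and values to the
    // probe-side ones so the joined row carries the columns of both tables.
    func concatRow(buildCols, probeCols []string, buildVals, probeVals []any) ([]string, []any) {
        cols := append(append([]string{}, buildCols...), probeCols...)
        vals := append(append([]any{}, buildVals...), probeVals...)
        return cols, vals
    }

    func main() {
        cols, vals := concatRow(
            []string{"uid", "name"}, []string{"emp_no", "salary"},
            []any{int64(1), "scott"}, []any{int64(1), 8000},
        )
        fmt.Println(cols, vals) // [uid name emp_no salary] [1 scott 1 8000]
    }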
+ + var b bytes.Buffer + if row.IsBinary() { + newRow := rows.NewBinaryVirtualRow(resFields, resDest) + _, err := newRow.WriteTo(&b) + if err != nil { + return nil, err + } + + br := mysql.NewBinaryRow(fields, b.Bytes()) + return br, nil + } else { + newRow := rows.NewTextVirtualRow(resFields, resDest) + _, err := newRow.WriteTo(&b) + if err != nil { + return nil, err + } + + return mysql.NewTextRow(fields, b.Bytes()), nil + } } - // todo left/right join - // filter match row & aggregate fields and row return dataset.Pipe(ds, dataset.Filter(filterFunc), dataset.Map(aggregateFieldsFunc, transformFunc)), nil } From b5377743a128440f2e2f749954a4d925a4dacce7 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 19 Mar 2023 23:08:58 +0800 Subject: [PATCH 07/19] optimize --- pkg/runtime/optimize/dml/select.go | 7 ++++--- pkg/runtime/plan/dml/hash_join.go | 9 ++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index 13d1af36..92fe79bc 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -425,6 +425,10 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt alias = tableSource.Alias database = table.Prefix() + if alias == "" { + alias = table.Suffix() + } + shards, err := o.ComputeShards(ctx, table, nil, o.Args) if err != nil { return @@ -440,9 +444,6 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt // table has shard shardsMap = shards - if alias == "" { - alias = table.Suffix() - } return } diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index 8952bdf1..1001ca1d 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -120,9 +120,12 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot probeMapFunc := func(row proto.Row, columnName string) proto.Row { keyedRow := row.(proto.KeyedRow) value, _ := keyedRow.Get(columnName) - xh := xxhash.New() - _, _ = xh.WriteString(value.String()) - return h.hashArea[base58.Encode(xh.Sum(nil))] + if value != nil { + xh := xxhash.New() + _, _ = xh.WriteString(value.String()) + return h.hashArea[base58.Encode(xh.Sum(nil))] + } + return nil } cn := h.ProbeKey[0] From 057c77dae0db206967c97fce19e8d3728022876b Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 19 Mar 2023 23:15:50 +0800 Subject: [PATCH 08/19] check nil --- pkg/runtime/plan/dml/hash_join.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index 1001ca1d..f672e5b8 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -99,8 +99,10 @@ func (h *HashJoinPlan) build(ctx context.Context, conn proto.VConn) (proto.Datas return nil, errors.WithStack(err) } - _, _ = xh.WriteString(value.String()) - h.hashArea[base58.Encode(xh.Sum(nil))] = next + if value != nil { + _, _ = xh.WriteString(value.String()) + h.hashArea[base58.Encode(xh.Sum(nil))] = next + } } return ds, nil From 381bee42f2b80984c37a2b4d70ca195b076e4627 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 16 Apr 2023 18:10:17 +0800 Subject: [PATCH 09/19] left/right join --- pkg/runtime/ast/select_element.go | 7 + pkg/runtime/optimize/dml/select.go | 220 +++++++++++++++++++++++++---- pkg/runtime/plan/dml/hash_join.go | 38 +++-- 3 files changed, 224 insertions(+), 41 deletions(-) diff --git 
a/pkg/runtime/ast/select_element.go b/pkg/runtime/ast/select_element.go index 97ba0814..de396341 100644 --- a/pkg/runtime/ast/select_element.go +++ b/pkg/runtime/ast/select_element.go @@ -267,6 +267,13 @@ func (s *SelectElementColumn) Suffix() string { return s.Name[len(s.Name)-1] } +func (s *SelectElementColumn) Prefix() string { + if len(s.Name) < 2 { + return "" + } + return s.Name[len(s.Name)-2] +} + func (s *SelectElementColumn) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { if err := ColumnNameExpressionAtom(s.Name).Restore(flag, sb, args); err != nil { return errors.WithStack(err) diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index 92fe79bc..34de4669 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -453,15 +453,12 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt return nil, err } - dbRight, aliasRight, shardsRight, err := compute(from.Joins[0].Target) + join := from.Joins[0] + dbRight, aliasRight, shardsRight, err := compute(join.Target) if err != nil { return nil, err } - //if dbLeft != "" && dbRight != "" && dbLeft != dbRight { - // return nil, errors.New("not support more than one db") - //} - // one db if dbLeft == dbRight && len(shardsLeft) == 1 && len(shardsRight) == 1 { joinPan := &dml.SimpleJoinPlan{ @@ -481,11 +478,34 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt } //multiple shards & do hash join - hashJoinPlan := &dml.HashJoinPlan{} + hashJoinPlan := &dml.HashJoinPlan{ + Stmt: stmt, + } + + onExpression, ok := from.Joins[0].On.(*ast.PredicateExpressionNode).P.(*ast.BinaryComparisonPredicateNode) + // todo support more 'ON' condition ast.LogicalExpressionNode + if !ok { + return nil, errors.New("not support more than one 'ON' condition") + } + + onLeft := onExpression.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + onRight := onExpression.Right.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + + leftKey := "" + if onLeft.Prefix() == aliasLeft { + leftKey = onLeft.Suffix() + } - //todo small table join large table + rightKey := "" + if onRight.Prefix() == aliasRight { + rightKey = onRight.Suffix() + } + + if len(leftKey) == 0 || len(rightKey) == 0 { + return nil, errors.Errorf("not found buildKey or probeKey") + } - rewriteToSingle := func(tableSource ast.TableSourceItem, shards map[string][]string) (proto.Plan, error) { + rewriteToSingle := func(tableSource ast.TableSourceItem, shards map[string][]string, onKey string) (proto.Plan, error) { selectStmt := &ast.SelectStatement{ Select: stmt.Select, From: ast.FromNode{ @@ -494,9 +514,47 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt }, }, } + table := tableSource.Source.(ast.TableName) + actualTb := table.Suffix() + aliasTb := tableSource.Alias + + tb0 := actualTb + for _, tb := range shards { + if len(tb) > 1 { + vt := o.Rule.MustVTable(tb0) + _, tb0, _ = vt.Topology().Smallest() + break + } + } + if _, ok = stmt.Select[0].(*ast.SelectElementAll); ok && len(stmt.Select) == 1 { + if err = rewriteSelectStatement(ctx, selectStmt, tb0); err != nil { + return nil, err + } + + selectStmt.Select = append(selectStmt.Select, ast.NewSelectElementColumn([]string{onKey}, "")) + } else { + metadata, err := loadMetadataByTable(ctx, tb0) + if err != nil { + return nil, err + } + + selectColumn := selectStmt.Select + var selectElements []ast.SelectElement + for _, element := range selectColumn { + e, ok 
:= element.(*ast.SelectElementColumn) + if ok { + for _, c := range metadata.ColumnNames { + if (aliasTb == e.Prefix() || actualTb == e.Prefix()) && c == e.Suffix() { + selectElements = append(selectElements, ast.NewSelectElementColumn([]string{c}, "")) + } + } + } + } + selectElements = append(selectElements, ast.NewSelectElementColumn([]string{onKey}, "")) + selectStmt.Select = selectElements + } var ( - err error optimizer proto.Optimizer plan proto.Plan sb strings.Builder @@ -524,38 +582,116 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt return plan, nil } - leftPlan, err := rewriteToSingle(from.TableSourceItem, shardsLeft) + leftPlan, err := rewriteToSingle(from.TableSourceItem, shardsLeft, leftKey) if err != nil { return nil, err } - hashJoinPlan.BuildPlan = leftPlan - rightPlan, err := rewriteToSingle(*from.Joins[0].Target, shardsRight) + rightPlan, err := rewriteToSingle(*from.Joins[0].Target, shardsRight, rightKey) if err != nil { return nil, err } - hashJoinPlan.ProbePlan = rightPlan - onExpression, ok := from.Joins[0].On.(*ast.PredicateExpressionNode).P.(*ast.BinaryComparisonPredicateNode) - // todo support more than one 'ON' condition ast.LogicalExpressionNode - if !ok { - return nil, errors.New("not support more than one 'ON' condition") + setPlan := func(plan *dml.HashJoinPlan, buildPlan, probePlan proto.Plan, buildKey, probeKey string) { + plan.BuildKey = buildKey + plan.ProbeKey = probeKey + plan.BuildPlan = buildPlan + plan.ProbePlan = probePlan } - onLeft := onExpression.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) - onRight := onExpression.Right.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + typ := join.Typ + if typ.String() == "INNER" { + setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey) + hashJoinPlan.IsFilterProbeRow = true + } else { + hashJoinPlan.IsFilterProbeRow = false + if typ.String() == "LEFT" { + setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey) + } else if typ.String() == "RIGHT" { + setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey) + } else { + return nil, errors.New("not support Join Type") + } + } + + var tmpPlan proto.Plan + tmpPlan = hashJoinPlan + + var ( + analysis selectResult + scanner = newSelectScanner(stmt, o.Args) + tableName = from.Source.(ast.TableName) + vt = o.Rule.MustVTable(tableName.Suffix()) + ) + + _, tb, _ := vt.Topology().Smallest() + if err = rewriteSelectStatement(ctx, stmt, tb); err != nil { + return nil, errors.WithStack(err) + } + + if err = scanner.scan(&analysis); err != nil { + return nil, errors.WithStack(err) + } + + // check if order-by exists + if len(analysis.orders) > 0 { + var ( + sb strings.Builder + orderByItems = make([]dataset.OrderByItem, 0, len(analysis.orders)) + ) - if onLeft[0] == aliasLeft { - hashJoinPlan.BuildKey = append(hashJoinPlan.BuildKey, onLeft[1]) + for _, it := range analysis.orders { + var next dataset.OrderByItem + next.Desc = it.Desc + if alias := it.Alias(); len(alias) > 0 { + next.Column = alias + } else { + switch prev := it.Prev().(type) { + case *ast.SelectElementColumn: + next.Column = prev.Suffix() + default: + if err = it.Restore(ast.RestoreWithoutAlias, &sb, nil); err != nil { + return nil, errors.WithStack(err) + } + next.Column = sb.String() + sb.Reset() + } + } + orderByItems = append(orderByItems, next) + } + tmpPlan = &dml.OrderPlan{ + ParentPlan: tmpPlan, + OrderByItems: orderByItems, + } } - if onRight[0] == aliasRight { - hashJoinPlan.ProbeKey = append(hashJoinPlan.ProbeKey, 
onRight[1]) + if stmt.GroupBy != nil { + if tmpPlan, err = handleGroupBy(tmpPlan, stmt); err != nil { + return nil, errors.WithStack(err) + } + } else if analysis.hasAggregate { + tmpPlan = &dml.AggregatePlan{ + Plan: tmpPlan, + Fields: stmt.Select, + } } - //todo order by, limit, group by, having etc.. + // overwrite stmt limit x offset y. eg `select * from student offset 100 limit 5` will be + // `select * from student offset 0 limit 100+5` + originOffset, newLimit := overwriteLimit(stmt, &o.Args) + if stmt.Limit != nil { + tmpPlan = &dml.LimitPlan{ + ParentPlan: tmpPlan, + OriginOffset: originOffset, + OverwriteLimit: newLimit, + } + } - return hashJoinPlan, nil + tmpPlan = &dml.RenamePlan{ + Plan: hashJoinPlan, + RenameList: analysis.normalizedFields, + } + return tmpPlan, nil } func getSelectFlag(ru *rule.Rule, stmt *ast.SelectStatement) (flag uint32) { @@ -660,20 +796,42 @@ func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, tb s if len(tb) < 1 { tb = stmt.From[0].Source.(ast.TableName).Suffix() } - metadatas, err := proto.LoadSchemaLoader().Load(ctx, rcontext.Schema(ctx), []string{tb}) + + metadata, err := loadMetadataByTable(ctx, tb) if err != nil { return errors.WithStack(err) } - metadata := metadatas[tb] - if metadata == nil || len(metadata.ColumnNames) == 0 { - return errors.Errorf("optimize: cannot get metadata of `%s`.`%s`", rcontext.Schema(ctx), tb) - } selectElements := make([]ast.SelectElement, len(metadata.Columns)) for i, column := range metadata.ColumnNames { selectElements[i] = ast.NewSelectElementColumn([]string{column}, "") } - stmt.Select = selectElements + if stmt.HasJoin() { + joinTable := stmt.From[0].Joins[0].Target.Source.(ast.TableName).Suffix() + joinTableMetadata, err := loadMetadataByTable(ctx, joinTable) + if err != nil { + return errors.WithStack(err) + } + + for column := range joinTableMetadata.Columns { + selectElements = append(selectElements, ast.NewSelectElementColumn([]string{column}, "")) + } + } + + stmt.Select = selectElements return nil } + +func loadMetadataByTable(ctx context.Context, tb string) (*proto.TableMetadata, error) { + metadatas, err := proto.LoadSchemaLoader().Load(ctx, rcontext.Schema(ctx), []string{tb}) + if err != nil { + return nil, errors.WithStack(err) + } + + metadata := metadatas[tb] + if metadata == nil || len(metadata.ColumnNames) == 0 { + return nil, errors.Errorf("optimize: cannot get metadata of `%s`.`%s`", rcontext.Schema(ctx), tb) + } + return metadata, nil +} diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index f672e5b8..b5cfff18 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -25,6 +25,7 @@ import ( "github.com/arana-db/arana/pkg/mysql/rows" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" "github.com/arana-db/arana/pkg/runtime/plan" "github.com/arana-db/arana/third_party/base58" "github.com/cespare/xxhash/v2" @@ -36,9 +37,12 @@ type HashJoinPlan struct { BuildPlan proto.Plan ProbePlan proto.Plan - BuildKey []string - ProbeKey []string - hashArea map[string]proto.Row + BuildKey string + ProbeKey string + hashArea map[string]proto.Row + IsFilterProbeRow bool + + Stmt *ast.SelectStatement } func (h *HashJoinPlan) Type() proto.PlanType { @@ -82,7 +86,7 @@ func (h *HashJoinPlan) build(ctx context.Context, conn proto.VConn) (proto.Datas if err != nil { return nil, errors.WithStack(err) } - cn := h.BuildKey[0] + cn := h.BuildKey xh := xxhash.New() 
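The build phase above amounts to indexing the build-side rows by a hash of their join-key value. A standard-library-only sketch (fnv standing in for xxhash/base58, keys compared as strings) that mirrors the same behavior of skipping NULL keys and keeping one row per key:

    package main

    import (
        "fmt"
        "hash/fnv"
    )

    type row map[string]string

    func hashKey(v string) uint64 {
        h := fnv.New64a()
        _, _ = h.Write([]byte(v))
        return h.Sum64()
    }

    // buildSide indexes rows by the hash of their join key; rows with an empty
    // (NULL) key are skipped, and a later duplicate key overwrites the earlier row.
    func buildSide(rows []row, key string) map[uint64]row {
        area := make(map[uint64]row, len(rows))
        for _, r := range rows {
            if v, ok := r[key]; ok && v != "" {
                area[hashKey(v)] = r
            }
        }
        return area
    }

    func main() {
        area := buildSide([]row{{"uid": "1"}, {"uid": "2"}}, "uid")
        if m, ok := area[hashKey("1")]; ok {
            fmt.Println("probe hit:", m) // probe hit: map[uid:1]
        }
    }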
h.hashArea = make(map[string]proto.Row) // build map @@ -130,19 +134,23 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot return nil } - cn := h.ProbeKey[0] + cn := h.ProbeKey filterFunc := func(row proto.Row) bool { findRow := probeMapFunc(row, cn) + if !h.IsFilterProbeRow { + return true + } + return findRow != nil } buildFields, _ := buildDs.Fields() // aggregate fields aggregateFieldsFunc := func(fields []proto.Field) []proto.Field { - return append(buildFields, fields...) + return append(buildFields[:len(buildFields)-1], fields[:len(fields)-1]...) } - // todo left/right join + // todo, 需要注意输出的列的顺序与join表的顺序 // aggregate row fields, _ := ds.Fields() @@ -152,10 +160,20 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot matchRow := probeMapFunc(row, cn) buildDest := make([]proto.Value, len(buildFields)) - _ = matchRow.Scan(buildDest) + if matchRow != nil { + _ = matchRow.Scan(buildDest) + } else { + // set null row + if row.IsBinary() { + matchRow = rows.NewBinaryVirtualRow(buildFields, buildDest) + } else { + matchRow = rows.NewTextVirtualRow(buildFields, buildDest) + } + } - resFields := append(buildFields, fields...) - resDest := append(buildDest, dest...) + // 去掉最后一个on字段 + resFields := append(buildFields[:len(buildFields)-1], fields[:len(fields)-1]...) + resDest := append(buildDest[:len(buildDest)-1], dest[:len(dest)-1]...) var b bytes.Buffer if row.IsBinary() { From 88432a779f62ba1726bb70d000bf59f113896521 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 16 Apr 2023 19:45:24 +0800 Subject: [PATCH 10/19] formatter import --- pkg/runtime/plan/dml/hash_join.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index b5cfff18..ebaf691a 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -20,6 +20,15 @@ package dml import ( "bytes" "context" + "io" +) + +import ( + "github.com/cespare/xxhash/v2" + "github.com/pkg/errors" +) + +import ( "github.com/arana-db/arana/pkg/dataset" "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/mysql/rows" @@ -28,9 +37,6 @@ import ( "github.com/arana-db/arana/pkg/runtime/ast" "github.com/arana-db/arana/pkg/runtime/plan" "github.com/arana-db/arana/third_party/base58" - "github.com/cespare/xxhash/v2" - "github.com/pkg/errors" - "io" ) type HashJoinPlan struct { From adf0ae86615d24599d5750128f8df5dbafaf1a5d Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 16 Apr 2023 20:13:14 +0800 Subject: [PATCH 11/19] fix ci --- pkg/runtime/optimize/dml/select.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index 34de4669..0b1c713b 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -109,13 +109,14 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err } } + if stmt.HasJoin() { + return optimizeJoin(ctx, o, stmt) + } // overwrite stmt limit x offset y. 
eg `select * from student offset 100 limit 5` will be // `select * from student offset 0 limit 100+5` originOffset, newLimit := overwriteLimit(stmt, &o.Args) - if stmt.HasJoin() { - return optimizeJoin(ctx, o, stmt) - } + flag := getSelectFlag(o.Rule, stmt) if flag&_supported == 0 { return nil, errors.Errorf("unsupported sql: %s", rcontext.SQL(ctx)) @@ -676,10 +677,10 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt } } - // overwrite stmt limit x offset y. eg `select * from student offset 100 limit 5` will be - // `select * from student offset 0 limit 100+5` - originOffset, newLimit := overwriteLimit(stmt, &o.Args) if stmt.Limit != nil { + // overwrite stmt limit x offset y. eg `select * from student offset 100 limit 5` will be + // `select * from student offset 0 limit 100+5` + originOffset, newLimit := overwriteLimit(stmt, &o.Args) tmpPlan = &dml.LimitPlan{ ParentPlan: tmpPlan, OriginOffset: originOffset, @@ -687,6 +688,13 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt } } + if analysis.hasMapping { + tmpPlan = &dml.MappingPlan{ + Plan: tmpPlan, + Fields: stmt.Select, + } + } + tmpPlan = &dml.RenamePlan{ Plan: hashJoinPlan, RenameList: analysis.normalizedFields, From f5419b7101bd5aecd5712e8529ee4075c4fda065 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sun, 16 Apr 2023 20:18:10 +0800 Subject: [PATCH 12/19] fix ci --- pkg/runtime/optimize/dml/select.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index 0b1c713b..fc2dc117 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -696,7 +696,7 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt } tmpPlan = &dml.RenamePlan{ - Plan: hashJoinPlan, + Plan: tmpPlan, RenameList: analysis.normalizedFields, } return tmpPlan, nil From dd8bb778f1793104946998ac0fe2ee3411f090d7 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Thu, 25 May 2023 23:04:10 +0800 Subject: [PATCH 13/19] add ut --- pkg/runtime/optimize/dml/select.go | 19 ++- pkg/runtime/optimize/optimizer_test.go | 82 +++++++++++- pkg/runtime/optimize/shard_visitor_test.go | 17 ++- pkg/runtime/plan/dml/hash_join.go | 8 +- pkg/runtime/plan/dml/hash_join_test.go | 144 +++++++++++++++++++++ 5 files changed, 246 insertions(+), 24 deletions(-) create mode 100644 pkg/runtime/plan/dml/hash_join_test.go diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index fc2dc117..68d0ef6c 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -19,7 +19,6 @@ package dml import ( "context" - "github.com/arana-db/parser" "strings" ) @@ -40,6 +39,7 @@ import ( "github.com/arana-db/arana/pkg/runtime/optimize/dml/ext" "github.com/arana-db/arana/pkg/runtime/plan/dml" "github.com/arana-db/arana/pkg/util/log" + "github.com/arana-db/parser" ) const ( @@ -544,10 +544,10 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt for _, element := range selectColumn { e, ok := element.(*ast.SelectElementColumn) if ok { - for _, c := range metadata.ColumnNames { - if (aliasTb == e.Prefix() || actualTb == e.Prefix()) && c == e.Suffix() { - selectElements = append(selectElements, ast.NewSelectElementColumn([]string{c}, "")) - } + columnsMap := metadata.Columns + ColumnMeta, exist := columnsMap[e.Suffix()] + if (aliasTb == e.Prefix() || actualTb == e.Prefix()) && exist { + 
selectElements = append(selectElements, ast.NewSelectElementColumn([]string{ColumnMeta.Name}, "")) } } } @@ -571,7 +571,7 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt return nil, err } - optimizer, err = optimize.NewOptimizer(o.Rule, o.Hints, stmtNode, o.Args) + optimizer, err = optimize.NewOptimizer(o.Rule, nil, stmtNode, nil) if err != nil { return nil, err } @@ -600,15 +600,14 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt plan.ProbePlan = probePlan } - typ := join.Typ - if typ.String() == "INNER" { + if join.Typ == ast.InnerJoin { setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey) hashJoinPlan.IsFilterProbeRow = true } else { hashJoinPlan.IsFilterProbeRow = false - if typ.String() == "LEFT" { + if join.Typ == ast.LeftJoin { setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey) - } else if typ.String() == "RIGHT" { + } else if join.Typ == ast.RightJoin { setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey) } else { return nil, errors.New("not support Join Type") diff --git a/pkg/runtime/optimize/optimizer_test.go b/pkg/runtime/optimize/optimizer_test.go index df506c96..fb39757d 100644 --- a/pkg/runtime/optimize/optimizer_test.go +++ b/pkg/runtime/optimize/optimizer_test.go @@ -20,6 +20,10 @@ package optimize_test import ( "context" "fmt" + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" "strings" "testing" ) @@ -64,7 +68,7 @@ func TestOptimizer_OptimizeSelect(t *testing.T) { var ( sql = "select id, uid from student where uid in (?,?,?)" ctx = context.WithValue(context.Background(), proto.ContextKeyEnableLocalComputation{}, true) - ru = makeFakeRule(ctrl, 8) + ru = makeFakeRule(ctrl, "student", 8, nil) ) p := parser.New() @@ -81,6 +85,80 @@ func TestOptimizer_OptimizeSelect(t *testing.T) { _, _ = plan.ExecIn(ctx, conn) } +func TestOptimizer_OptimizeHashJoin(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + studentFields := []proto.Field{ + mysql.NewField("uid", consts.FieldTypeLongLong), + mysql.NewField("name", consts.FieldTypeString), + } + + salariesFields := []proto.Field{ + mysql.NewField("emp_no", consts.FieldTypeLongLong), + mysql.NewField("name", consts.FieldTypeString), + } + + conn := testdata.NewMockVConn(ctrl) + buildPlan := true + conn.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) { + t.Logf("fake query: db=%s, sql=%s, args=%v\n", db, sql, args) + + result := testdata.NewMockResult(ctrl) + fakeData := &dataset.VirtualDataset{} + if buildPlan { + fakeData.Columns = append(studentFields, mysql.NewField("uid", consts.FieldTypeLongLong)) + for i := int64(0); i < 8; i++ { + fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ + proto.NewValueInt64(i), + proto.NewValueString(fmt.Sprintf("fake-student-name-%d", i)), + proto.NewValueInt64(i), + })) + } + result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() + buildPlan = false + } else { + fakeData.Columns = append(salariesFields, mysql.NewField("emp_no", consts.FieldTypeLongLong)) + for i := int64(10); i > 3; i-- { + fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ + proto.NewValueInt64(i), + proto.NewValueString(fmt.Sprintf("fake-salaries-name-%d", i)), + proto.NewValueInt64(i), + })) + } + result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() + } + + return result, nil + }). + AnyTimes() + + var ( + sql = "select * from student join salaries on uid = emp_no" + ctx = context.WithValue(context.Background(), proto.ContextKeyEnableLocalComputation{}, true) + ru = makeFakeRule(ctrl, "student", 8, nil) + ) + + ru = makeFakeRule(ctrl, "salaries", 8, ru) + + p := parser.New() + stmt, _ := p.ParseOneStmt(sql, "", "") + opt, err := NewOptimizer(ru, nil, stmt, nil) + assert.NoError(t, err) + + vTable, _ := ru.VTable("student") + vTable.SetAllowFullScan(true) + + vTable2, _ := ru.VTable("salaries") + vTable2.SetAllowFullScan(true) + + plan, err := opt.Optimize(ctx) + assert.NoError(t, err) + + _, _ = plan.ExecIn(ctx, conn) +} + func TestOptimizer_OptimizeInsert(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -143,7 +221,7 @@ func TestOptimizer_OptimizeInsert(t *testing.T) { var ( ctx = context.Background() - ru = makeFakeRule(ctrl, 8) + ru = makeFakeRule(ctrl, "student", 8, nil) ) t.Run("sharding", func(t *testing.T) { diff --git a/pkg/runtime/optimize/shard_visitor_test.go b/pkg/runtime/optimize/shard_visitor_test.go index 04bd3d70..34736a16 100644 --- a/pkg/runtime/optimize/shard_visitor_test.go +++ b/pkg/runtime/optimize/shard_visitor_test.go @@ -44,7 +44,7 @@ func TestShardNG(t *testing.T) { defer ctrl.Finish() // test rule: student, uid % 8 - fakeRule := makeFakeRule(ctrl, 8) + fakeRule := makeFakeRule(ctrl, "student", 8, nil) type tt struct { sql string @@ -78,17 +78,20 @@ func TestShardNG(t *testing.T) { } } -func makeFakeRule(c *gomock.Controller, mod int) *rule.Rule { +func makeFakeRule(c *gomock.Controller, table string, mod int, ru *rule.Rule) *rule.Rule { var ( - ru rule.Rule tab rule.VTable topo rule.Topology ) + if ru == nil { + ru = &rule.Rule{} + } + topo.SetRender(func(_ int) string { return "fake_db" }, func(i int) string { - return fmt.Sprintf("student_%04d", i) + return fmt.Sprintf("%s_%04d", table, i) }) tables := make([]int, 0, mod) @@ -98,7 +101,7 @@ func makeFakeRule(c *gomock.Controller, mod int) *rule.Rule { topo.SetTopology(0, tables...) 
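The fake rule assembled here models a simple modulo sharding scheme; a tiny sketch of the rule it stands for, with hypothetical names:

    package main

    import "fmt"

    // shardOf renders the physical location of a row the way the fake rule does:
    // one database, table suffix uid % mod formatted as "<table>_%04d".
    func shardOf(table string, uid, mod int) (db, physicalTable string) {
        return "fake_db", fmt.Sprintf("%s_%04d", table, uid%mod)
    }

    func main() {
        db, tb := shardOf("student", 42, 8)
        fmt.Println(db, tb) // fake_db student_0002
    }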
tab.SetTopology(&topo) - tab.SetName("student") + tab.SetName(table) computer := testdata.NewMockShardComputer(c) @@ -118,6 +121,6 @@ func makeFakeRule(c *gomock.Controller, mod int) *rule.Rule { sm.Computer = computer tab.SetShardMetadata("uid", nil, &sm) - ru.SetVTable("student", &tab) - return &ru + ru.SetVTable(table, &tab) + return ru } diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index ebaf691a..dced40d6 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -156,8 +156,6 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot return append(buildFields[:len(buildFields)-1], fields[:len(fields)-1]...) } - // todo, 需要注意输出的列的顺序与join表的顺序 - // aggregate row fields, _ := ds.Fields() transformFunc := func(row proto.Row) (proto.Row, error) { @@ -177,7 +175,7 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot } } - // 去掉最后一个on字段 + // remove 'ON' column resFields := append(buildFields[:len(buildFields)-1], fields[:len(fields)-1]...) resDest := append(buildDest[:len(buildDest)-1], dest[:len(dest)-1]...) @@ -189,7 +187,7 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot return nil, err } - br := mysql.NewBinaryRow(fields, b.Bytes()) + br := mysql.NewBinaryRow(resFields, b.Bytes()) return br, nil } else { newRow := rows.NewTextVirtualRow(resFields, resDest) @@ -198,7 +196,7 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot return nil, err } - return mysql.NewTextRow(fields, b.Bytes()), nil + return mysql.NewTextRow(resFields, b.Bytes()), nil } } diff --git a/pkg/runtime/plan/dml/hash_join_test.go b/pkg/runtime/plan/dml/hash_join_test.go new file mode 100644 index 00000000..848dd194 --- /dev/null +++ b/pkg/runtime/plan/dml/hash_join_test.go @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package dml + +import ( + "context" + "fmt" + "io" + "testing" +) + +import ( + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/testdata" +) + +func TestHashJoinPlan(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + studentFields := []proto.Field{ + mysql.NewField("uid", consts.FieldTypeLongLong), + mysql.NewField("name", consts.FieldTypeString), + } + + salariesFields := []proto.Field{ + mysql.NewField("emp_no", consts.FieldTypeLongLong), + mysql.NewField("name", consts.FieldTypeString), + } + + buildPlan := true + conn := testdata.NewMockVConn(ctrl) + conn.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) { + t.Logf("fake query: db=%s, sql=%s, args=%v\n", db, sql, args) + + result := testdata.NewMockResult(ctrl) + fakeData := &dataset.VirtualDataset{} + if buildPlan { + fakeData.Columns = append(studentFields, mysql.NewField("uid", consts.FieldTypeLongLong)) + for i := int64(0); i < 8; i++ { + fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ + proto.NewValueInt64(i), + proto.NewValueString(fmt.Sprintf("fake-student-name-%d", i)), + proto.NewValueInt64(i), + })) + } + result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() + buildPlan = false + } else { + fakeData.Columns = append(salariesFields, mysql.NewField("emp_no", consts.FieldTypeLongLong)) + for i := int64(10); i > 3; i-- { + fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ + proto.NewValueInt64(i), + proto.NewValueString(fmt.Sprintf("fake-salaries-name-%d", i)), + proto.NewValueInt64(i), + })) + } + result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() + } + + return result, nil + }). 
+ AnyTimes() + + var ( + sql1 = "SELECT * FROM student" // mock build plan + sql2 = "SELECT * FROM salaries" // mock probe plan + ctx = context.WithValue(context.Background(), proto.ContextKeyEnableLocalComputation{}, true) + ) + + _, stmt1, _ := ast.ParseSelect(sql1) + _, stmt2, _ := ast.ParseSelect(sql2) + + // sql: select * from student join salaries on uid = emp_no; + plan := &HashJoinPlan{ + BuildPlan: CompositePlan{ + []proto.Plan{ + &SimpleQueryPlan{ + Stmt: stmt1, + }, + }, + }, + ProbePlan: CompositePlan{ + []proto.Plan{ + &SimpleQueryPlan{ + Stmt: stmt2, + }, + }, + }, + IsFilterProbeRow: true, + BuildKey: "uid", + ProbeKey: "emp_no", + } + + res, err := plan.ExecIn(ctx, conn) + assert.NoError(t, err) + ds, _ := res.Dataset() + f, _ := ds.Fields() + + // expected field + assert.Equal(t, "uid", f[0].Name()) + assert.Equal(t, "name", f[1].Name()) + assert.Equal(t, "emp_no", f[2].Name()) + assert.Equal(t, "name", f[3].Name()) + for { + next, err := ds.Next() + if err == io.EOF { + break + } + row := next.(proto.Row) + dest := make([]proto.Value, len(f)) + _ = row.Scan(dest) + + // expected value: uid = emp_no + assert.Equal(t, dest[0], dest[2]) + } + +} From 57f584e657c185a031a7a48f08124f14f6eaa045 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sat, 27 May 2023 20:00:42 +0800 Subject: [PATCH 14/19] fix ut --- pkg/runtime/optimize/dml/select.go | 132 ++++++++------------- pkg/runtime/optimize/optimizer_test.go | 30 +++-- pkg/runtime/optimize/shard_visitor_test.go | 2 +- pkg/runtime/plan/dml/hash_join.go | 12 +- 4 files changed, 83 insertions(+), 93 deletions(-) diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index 68d0ef6c..604439f9 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -39,7 +39,6 @@ import ( "github.com/arana-db/arana/pkg/runtime/optimize/dml/ext" "github.com/arana-db/arana/pkg/runtime/plan/dml" "github.com/arana-db/arana/pkg/util/log" - "github.com/arana-db/parser" ) const ( @@ -124,7 +123,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err if flag&_bypass != 0 { if len(stmt.From) > 0 { - err := rewriteSelectStatement(ctx, stmt, stmt.From[0].Source.(ast.TableName).Suffix()) + err := rewriteSelectStatement(ctx, stmt, o) if err != nil { return nil, err } @@ -173,8 +172,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err } toSingle := func(db, tbl string) (proto.Plan, error) { - _, tb0, _ := vt.Topology().Smallest() - if err := rewriteSelectStatement(ctx, stmt, tb0); err != nil { + if err := rewriteSelectStatement(ctx, stmt, o); err != nil { return nil, err } ret := &dml.SimpleQueryPlan{ @@ -220,8 +218,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err return toSingle(db, tbl) } - _, tb, _ := vt.Topology().Smallest() - if err = rewriteSelectStatement(ctx, stmt, tb); err != nil { + if err = rewriteSelectStatement(ctx, stmt, o); err != nil { return nil, errors.WithStack(err) } @@ -417,8 +414,8 @@ func handleGroupBy(parentPlan proto.Plan, stmt *ast.SelectStatement) (proto.Plan // optimizeJoin ony support a join b in one db. 
// DEPRECATED: reimplement in the future func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectStatement) (proto.Plan, error) { - compute := func(tableSource *ast.TableSourceItem) (database, alias string, shardsMap map[string][]string, err error) { - table := tableSource.Source.(ast.TableName) + compute := func(tableSource *ast.TableSourceItem) (database, alias string, table ast.TableName, shards rule.DatabaseTables, err error) { + table = tableSource.Source.(ast.TableName) if table == nil { err = errors.New("must table, not statement or join node") return @@ -430,46 +427,35 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt alias = table.Suffix() } - shards, err := o.ComputeShards(ctx, table, nil, o.Args) + shards, err = o.ComputeShards(ctx, table, nil, o.Args) if err != nil { return } - - shardsMap = make(map[string][]string, len(shards)) - - // table no shard - if shards == nil { - shardsMap[database] = append(shardsMap[database], table.Suffix()) - return - } - - // table has shard - shardsMap = shards return } from := stmt.From[0] - dbLeft, aliasLeft, shardsLeft, err := compute(&from.TableSourceItem) + dbLeft, aliasLeft, tableLeft, shardsLeft, err := compute(&from.TableSourceItem) if err != nil { return nil, err } join := from.Joins[0] - dbRight, aliasRight, shardsRight, err := compute(join.Target) + dbRight, aliasRight, tableRight, shardsRight, err := compute(join.Target) if err != nil { return nil, err } // one db - if dbLeft == dbRight && len(shardsLeft) == 1 && len(shardsRight) == 1 { + if dbLeft == dbRight && shardsLeft == nil && shardsRight == nil { joinPan := &dml.SimpleJoinPlan{ Left: &dml.JoinTable{ - Tables: shardsLeft[dbLeft], + Tables: tableLeft, Alias: aliasLeft, }, Join: from.Joins[0], Right: &dml.JoinTable{ - Tables: shardsRight[dbRight], + Tables: tableRight, Alias: aliasRight, }, Stmt: o.Stmt.(*ast.SelectStatement), @@ -515,25 +501,17 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt }, }, } - table := tableSource.Source.(ast.TableName) - actualTb := table.Suffix() - aliasTb := tableSource.Alias - tb0 := actualTb - for _, tb := range shards { - if len(tb) > 1 { + if _, ok = stmt.Select[0].(*ast.SelectElementAll); !ok && len(stmt.Select) > 1 { + table := tableSource.Source.(ast.TableName) + actualTb := table.Suffix() + aliasTb := tableSource.Alias + + tb0 := actualTb + if shards != nil { vt := o.Rule.MustVTable(tb0) _, tb0, _ = vt.Topology().Smallest() - break } - } - if _, ok = stmt.Select[0].(*ast.SelectElementAll); ok && len(stmt.Select) == 1 { - if err = rewriteSelectStatement(ctx, selectStmt, tb0); err != nil { - return nil, err - } - - selectStmt.Select = append(selectStmt.Select, ast.NewSelectElementColumn([]string{onKey}, "")) - } else { metadata, err := loadMetadataByTable(ctx, tb0) if err != nil { return nil, err @@ -555,28 +533,19 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt selectStmt.Select = selectElements } - var ( - optimizer proto.Optimizer - plan proto.Plan - sb strings.Builder - ) - err = selectStmt.Restore(ast.RestoreDefault, &sb, nil) - if err != nil { - return nil, err - } - - p := parser.New() - stmtNode, err := p.ParseOneStmt(sb.String(), "", "") - if err != nil { - return nil, err + optimizer := &optimize.Optimizer{ + Rule: o.Rule, + Stmt: selectStmt, } + if _, ok = selectStmt.Select[0].(*ast.SelectElementAll); ok && len(selectStmt.Select) == 1 { + if err = rewriteSelectStatement(ctx, selectStmt, optimizer); err != 
nil { + return nil, err + } - optimizer, err = optimize.NewOptimizer(o.Rule, nil, stmtNode, nil) - if err != nil { - return nil, err + selectStmt.Select = append(selectStmt.Select, ast.NewSelectElementColumn([]string{onKey}, "")) } - plan, err = optimizeSelect(ctx, optimizer.(*optimize.Optimizer)) + plan, err := optimizeSelect(ctx, optimizer) if err != nil { return nil, err } @@ -618,14 +587,11 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt tmpPlan = hashJoinPlan var ( - analysis selectResult - scanner = newSelectScanner(stmt, o.Args) - tableName = from.Source.(ast.TableName) - vt = o.Rule.MustVTable(tableName.Suffix()) + analysis selectResult + scanner = newSelectScanner(stmt, o.Args) ) - _, tb, _ := vt.Topology().Smallest() - if err = rewriteSelectStatement(ctx, stmt, tb); err != nil { + if err = rewriteSelectStatement(ctx, stmt, o); err != nil { return nil, errors.WithStack(err) } @@ -787,7 +753,7 @@ func overwriteLimit(stmt *ast.SelectStatement, args *[]proto.Value) (originOffse return } -func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, tb string) error { +func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, o *optimize.Optimizer) error { // todo db 计算逻辑&tb shard 的计算逻辑 starExpand := false if len(stmt.Select) == 1 { @@ -800,33 +766,35 @@ func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, tb s return nil } - if len(tb) < 1 { - tb = stmt.From[0].Source.(ast.TableName).Suffix() + tbs := []ast.TableName{stmt.From[0].Source.(ast.TableName)} + for _, join := range stmt.From[0].Joins { + joinTable := join.Target.Source.(ast.TableName) + tbs = append(tbs, joinTable) } - metadata, err := loadMetadataByTable(ctx, tb) - if err != nil { - return errors.WithStack(err) - } + selectExpandElements := make([]ast.SelectElement, 0) + for _, t := range tbs { + shards, err := o.ComputeShards(ctx, t, nil, o.Args) + if err != nil { + return errors.WithStack(err) + } - selectElements := make([]ast.SelectElement, len(metadata.Columns)) - for i, column := range metadata.ColumnNames { - selectElements[i] = ast.NewSelectElementColumn([]string{column}, "") - } + tb0 := t.Suffix() + if shards != nil { + vt := o.Rule.MustVTable(tb0) + _, tb0, _ = vt.Topology().Smallest() + } - if stmt.HasJoin() { - joinTable := stmt.From[0].Joins[0].Target.Source.(ast.TableName).Suffix() - joinTableMetadata, err := loadMetadataByTable(ctx, joinTable) + metadata, err := loadMetadataByTable(ctx, tb0) if err != nil { return errors.WithStack(err) } - for column := range joinTableMetadata.Columns { - selectElements = append(selectElements, ast.NewSelectElementColumn([]string{column}, "")) + for _, column := range metadata.ColumnNames { + selectExpandElements = append(selectExpandElements, ast.NewSelectElementColumn([]string{column}, "")) } } - - stmt.Select = selectElements + stmt.Select = selectExpandElements return nil } diff --git a/pkg/runtime/optimize/optimizer_test.go b/pkg/runtime/optimize/optimizer_test.go index fb39757d..0049b5a1 100644 --- a/pkg/runtime/optimize/optimizer_test.go +++ b/pkg/runtime/optimize/optimizer_test.go @@ -91,12 +91,10 @@ func TestOptimizer_OptimizeHashJoin(t *testing.T) { studentFields := []proto.Field{ mysql.NewField("uid", consts.FieldTypeLongLong), - mysql.NewField("name", consts.FieldTypeString), } salariesFields := []proto.Field{ - mysql.NewField("emp_no", consts.FieldTypeLongLong), - mysql.NewField("name", consts.FieldTypeString), + mysql.NewField("uid", consts.FieldTypeLongLong), } conn := 
testdata.NewMockVConn(ctrl) @@ -112,18 +110,16 @@ func TestOptimizer_OptimizeHashJoin(t *testing.T) { for i := int64(0); i < 8; i++ { fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ proto.NewValueInt64(i), - proto.NewValueString(fmt.Sprintf("fake-student-name-%d", i)), proto.NewValueInt64(i), })) } result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() buildPlan = false } else { - fakeData.Columns = append(salariesFields, mysql.NewField("emp_no", consts.FieldTypeLongLong)) + fakeData.Columns = append(salariesFields, mysql.NewField("uid", consts.FieldTypeLongLong)) for i := int64(10); i > 3; i-- { fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ proto.NewValueInt64(i), - proto.NewValueString(fmt.Sprintf("fake-salaries-name-%d", i)), proto.NewValueInt64(i), })) } @@ -134,8 +130,28 @@ func TestOptimizer_OptimizeHashJoin(t *testing.T) { }). AnyTimes() + fakeData := make(map[string]*proto.TableMetadata) + // fake data + fakeData["student_0000"] = &proto.TableMetadata{ + Name: "student_0000", + Columns: map[string]*proto.ColumnMetadata{"uid": {}}, + ColumnNames: []string{"uid"}, + } + + fakeData["salaries_0000"] = &proto.TableMetadata{ + Name: "salaries_0000", + Columns: map[string]*proto.ColumnMetadata{"uid": {}}, + ColumnNames: []string{"uid"}, + } + loader := testdata.NewMockSchemaLoader(ctrl) + loader.EXPECT().Load(gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeData, nil).AnyTimes() + + oldLoader := proto.LoadSchemaLoader() + proto.RegisterSchemaLoader(loader) + defer proto.RegisterSchemaLoader(oldLoader) + var ( - sql = "select * from student join salaries on uid = emp_no" + sql = "select * from student join salaries on student.uid = salaries.uid" ctx = context.WithValue(context.Background(), proto.ContextKeyEnableLocalComputation{}, true) ru = makeFakeRule(ctrl, "student", 8, nil) ) diff --git a/pkg/runtime/optimize/shard_visitor_test.go b/pkg/runtime/optimize/shard_visitor_test.go index 34736a16..4753a7ea 100644 --- a/pkg/runtime/optimize/shard_visitor_test.go +++ b/pkg/runtime/optimize/shard_visitor_test.go @@ -114,7 +114,7 @@ func makeFakeRule(c *gomock.Controller, table string, mod int, ru *rule.Rule) *r } return n % mod, nil }). - MinTimes(1) + AnyTimes() var sm rule.ShardMetadata sm.Steps = 8 diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index dced40d6..61c1fe0d 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -118,7 +118,7 @@ func (h *HashJoinPlan) build(ctx context.Context, conn proto.VConn) (proto.Datas return ds, nil } -func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs proto.Dataset) (proto.Dataset, error) { +func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDataset proto.Dataset) (proto.Dataset, error) { res, err := h.queryAggregate(ctx, conn, h.ProbePlan) if err != nil { return nil, errors.WithStack(err) @@ -150,14 +150,20 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDs prot return findRow != nil } - buildFields, _ := buildDs.Fields() + buildFields, err := buildDataset.Fields() + if err != nil { + return nil, errors.WithStack(err) + } // aggregate fields aggregateFieldsFunc := func(fields []proto.Field) []proto.Field { return append(buildFields[:len(buildFields)-1], fields[:len(fields)-1]...) 
} // aggregate row - fields, _ := ds.Fields() + fields, err := ds.Fields() + if err != nil { + return nil, errors.WithStack(err) + } transformFunc := func(row proto.Row) (proto.Row, error) { dest := make([]proto.Value, len(fields)) _ = row.Scan(dest) From 104d29e4873a023743d3a862d51abf3c279b1a64 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Tue, 30 May 2023 22:03:08 +0800 Subject: [PATCH 15/19] fix ci --- pkg/runtime/optimize/dml/select.go | 1 + pkg/runtime/plan/dml/hash_join.go | 19 +++++++++++++++++-- pkg/runtime/plan/dml/hash_join_test.go | 3 +-- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index 604439f9..cb9d99d2 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -575,6 +575,7 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt } else { hashJoinPlan.IsFilterProbeRow = false if join.Typ == ast.LeftJoin { + hashJoinPlan.IsReversedColumn = true setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey) } else if join.Typ == ast.RightJoin { setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey) diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index 61c1fe0d..7decd5e7 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -47,6 +47,7 @@ type HashJoinPlan struct { ProbeKey string hashArea map[string]proto.Row IsFilterProbeRow bool + IsReversedColumn bool Stmt *ast.SelectStatement } @@ -156,6 +157,10 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDataset } // aggregate fields aggregateFieldsFunc := func(fields []proto.Field) []proto.Field { + if h.IsReversedColumn { + return append(fields[:len(fields)-1], buildFields[:len(buildFields)-1]...) + } + return append(buildFields[:len(buildFields)-1], fields[:len(fields)-1]...) } @@ -181,9 +186,19 @@ func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDataset } } + var ( + resFields []proto.Field + resDest []proto.Value + ) + // remove 'ON' column - resFields := append(buildFields[:len(buildFields)-1], fields[:len(fields)-1]...) - resDest := append(buildDest[:len(buildDest)-1], dest[:len(dest)-1]...) + if h.IsReversedColumn { + resFields = append(fields[:len(fields)-1], buildFields[:len(buildFields)-1]...) + resDest = append(dest[:len(dest)-1], buildDest[:len(buildDest)-1]...) + } else { + resFields = append(buildFields[:len(buildFields)-1], fields[:len(fields)-1]...) + resDest = append(buildDest[:len(buildDest)-1], dest[:len(dest)-1]...) 
+ } var b bytes.Buffer if row.IsBinary() { diff --git a/pkg/runtime/plan/dml/hash_join_test.go b/pkg/runtime/plan/dml/hash_join_test.go index 848dd194..2b8adcb2 100644 --- a/pkg/runtime/plan/dml/hash_join_test.go +++ b/pkg/runtime/plan/dml/hash_join_test.go @@ -133,9 +133,8 @@ func TestHashJoinPlan(t *testing.T) { if err == io.EOF { break } - row := next.(proto.Row) dest := make([]proto.Value, len(f)) - _ = row.Scan(dest) + _ = next.Scan(dest) // expected value: uid = emp_no assert.Equal(t, dest[0], dest[2]) From 5a63e761fe8f3aae5c183007cab8897fbba6cbd8 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Fri, 9 Jun 2023 16:24:05 +0800 Subject: [PATCH 16/19] fix where --- pkg/runtime/ast/expression.go | 21 ++++ pkg/runtime/ast/expression_atom.go | 54 +++++++++ pkg/runtime/ast/predicate.go | 53 +++++++++ pkg/runtime/optimize/dml/select.go | 172 +++++++++++++++++++++++++++-- 4 files changed, 291 insertions(+), 9 deletions(-) diff --git a/pkg/runtime/ast/expression.go b/pkg/runtime/ast/expression.go index f2c245cf..6a40242f 100644 --- a/pkg/runtime/ast/expression.go +++ b/pkg/runtime/ast/expression.go @@ -48,6 +48,7 @@ type ExpressionNode interface { Node Restorer Mode() ExpressionMode + Clone() ExpressionNode } type LogicalExpressionNode struct { @@ -85,6 +86,14 @@ func (l *LogicalExpressionNode) Mode() ExpressionMode { return EmLogical } +func (l *LogicalExpressionNode) Clone() ExpressionNode { + return &LogicalExpressionNode{ + Op: l.Op, + Left: l.Left.Clone(), + Right: l.Right.Clone(), + } +} + type NotExpressionNode struct { E ExpressionNode } @@ -105,6 +114,12 @@ func (n *NotExpressionNode) Mode() ExpressionMode { return EmNot } +func (n *NotExpressionNode) Clone() ExpressionNode { + return &NotExpressionNode{ + E: n.E.Clone(), + } +} + type PredicateExpressionNode struct { P PredicateNode } @@ -123,3 +138,9 @@ func (a *PredicateExpressionNode) Restore(flag RestoreFlag, sb *strings.Builder, func (a *PredicateExpressionNode) Mode() ExpressionMode { return EmPredicate } + +func (a *PredicateExpressionNode) Clone() ExpressionNode { + return &PredicateExpressionNode{ + P: a.P.Clone(), + } +} diff --git a/pkg/runtime/ast/expression_atom.go b/pkg/runtime/ast/expression_atom.go index d6100b18..85aecf6f 100644 --- a/pkg/runtime/ast/expression_atom.go +++ b/pkg/runtime/ast/expression_atom.go @@ -60,6 +60,7 @@ type ExpressionAtom interface { Node Restorer phantom() expressionAtomPhantom + Clone() ExpressionAtom } type IntervalExpressionAtom struct { @@ -103,6 +104,13 @@ func (ie *IntervalExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (ie *IntervalExpressionAtom) Clone() ExpressionAtom { + return &IntervalExpressionAtom{ + Unit: ie.Unit, + Value: ie.Value.Clone(), + } +} + type SystemVariableExpressionAtom struct { Name string System bool @@ -145,6 +153,14 @@ func (sy *SystemVariableExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (sy *SystemVariableExpressionAtom) Clone() ExpressionAtom { + return &SystemVariableExpressionAtom{ + Name: sy.Name, + System: sy.System, + Global: sy.Global, + } +} + type UnaryExpressionAtom struct { Operator string Inner Node // ExpressionAtom or *BinaryComparisonPredicateNode @@ -184,6 +200,10 @@ func (u *UnaryExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (u *UnaryExpressionAtom) Clone() ExpressionAtom { + panic("implement me") +} + type ConstantExpressionAtom struct { Inner interface{} } @@ -201,6 +221,12 @@ func (c 
*ConstantExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (c *ConstantExpressionAtom) Clone() ExpressionAtom { + return &ConstantExpressionAtom{ + Inner: c.Inner, + } +} + func constant2string(value interface{}) string { switch v := value.(type) { case Null: @@ -299,6 +325,12 @@ func (c ColumnNameExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (c ColumnNameExpressionAtom) Clone() ExpressionAtom { + res := make(ColumnNameExpressionAtom, len(c)) + copy(res, c) + return res +} + type VariableExpressionAtom int func (v VariableExpressionAtom) Accept(visitor Visitor) (interface{}, error) { @@ -323,6 +355,10 @@ func (v VariableExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (v VariableExpressionAtom) Clone() ExpressionAtom { + return v +} + type MathExpressionAtom struct { Left ExpressionAtom Operator string @@ -357,6 +393,14 @@ func (m *MathExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (m *MathExpressionAtom) Clone() ExpressionAtom { + return &MathExpressionAtom{ + Left: m.Left.Clone(), + Operator: m.Operator, + Right: m.Right.Clone(), + } +} + type NestedExpressionAtom struct { First ExpressionNode } @@ -379,6 +423,12 @@ func (n *NestedExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (n *NestedExpressionAtom) Clone() ExpressionAtom { + return &NestedExpressionAtom{ + First: n.First.Clone(), + } +} + type FunctionCallExpressionAtom struct { F Node // *Function OR *AggrFunction OR *CaseWhenElseFunction OR *CastFunction } @@ -412,3 +462,7 @@ func (f *FunctionCallExpressionAtom) Restore(flag RestoreFlag, sb *strings.Build func (f *FunctionCallExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } + +func (f *FunctionCallExpressionAtom) Clone() ExpressionAtom { + panic("implement me") +} diff --git a/pkg/runtime/ast/predicate.go b/pkg/runtime/ast/predicate.go index a79e4109..265cc3f1 100644 --- a/pkg/runtime/ast/predicate.go +++ b/pkg/runtime/ast/predicate.go @@ -44,6 +44,7 @@ type PredicateNode interface { Node Restorer phantom() predicateNodePhantom + Clone() PredicateNode } type LikePredicateNode struct { @@ -81,6 +82,14 @@ func (l *LikePredicateNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } +func (l *LikePredicateNode) Clone() PredicateNode { + return &LikePredicateNode{ + Not: l.Not, + Left: l.Left.Clone(), + Right: l.Right.Clone(), + } +} + type RegexpPredicationNode struct { Left PredicateNode Right PredicateNode @@ -111,6 +120,14 @@ func (rp *RegexpPredicationNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } +func (rp *RegexpPredicationNode) Clone() PredicateNode { + return &RegexpPredicationNode{ + Left: rp.Left.Clone(), + Right: rp.Right.Clone(), + Not: rp.Not, + } +} + type BinaryComparisonPredicateNode struct { Left PredicateNode Right PredicateNode @@ -158,6 +175,14 @@ func (b *BinaryComparisonPredicateNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } +func (b *BinaryComparisonPredicateNode) Clone() PredicateNode { + return &BinaryComparisonPredicateNode{ + Left: b.Left.Clone(), + Right: b.Right.Clone(), + Op: b.Op, + } +} + type AtomPredicateNode struct { A ExpressionAtom } @@ -185,6 +210,12 @@ func (a *AtomPredicateNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } +func (a *AtomPredicateNode) Clone() PredicateNode { + return &AtomPredicateNode{ + A: a.A.Clone(), + } +} 
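The Clone() methods introduced in this patch give the join optimizer a way to take a private deep copy of the WHERE tree for each join side and rewrite that copy (for example swapping the other table's predicates for a constant 1 = 1) without disturbing the original statement. A minimal sketch of the deep-copy idea, using a made-up expr type instead of the real AST nodes:

package main

import "fmt"

// expr is a toy stand-in for an AST expression node; it is not part of arana.
type expr struct {
	op          string
	left, right *expr
}

// clone deep-copies the whole tree, the same idea as the Clone() methods
// added to ExpressionNode, PredicateNode and ExpressionAtom in this patch.
func (e *expr) clone() *expr {
	if e == nil {
		return nil
	}
	return &expr{op: e.op, left: e.left.clone(), right: e.right.clone()}
}

func main() {
	where := &expr{
		op:    "AND",
		left:  &expr{op: "student.uid = ?"},
		right: &expr{op: "salaries.salary > ?"},
	}

	// Each join side rewrites its own copy, so filtering one side's
	// predicates down to 1 = 1 cannot leak into the shared original.
	leftSide := where.clone()
	leftSide.right = &expr{op: "1 = 1"}

	fmt.Println(where.right.op) // still "salaries.salary > ?"
}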
+ type BetweenPredicateNode struct { Not bool Key PredicateNode @@ -222,6 +253,15 @@ func (b *BetweenPredicateNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } +func (b *BetweenPredicateNode) Clone() PredicateNode { + return &BetweenPredicateNode{ + Not: b.Not, + Key: b.Key.Clone(), + Left: b.Left.Clone(), + Right: b.Right.Clone(), + } +} + type InPredicateNode struct { Not bool P PredicateNode @@ -264,3 +304,16 @@ func (ip *InPredicateNode) Restore(flag RestoreFlag, sb *strings.Builder, args * func (ip *InPredicateNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } + +func (ip *InPredicateNode) Clone() PredicateNode { + e := make([]ExpressionNode, 0, len(ip.E)) + for _, node := range ip.E { + e = append(e, node.Clone()) + } + + return &InPredicateNode{ + Not: ip.Not, + P: ip.P.Clone(), + E: e, + } +} diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index cb9d99d2..5723bfcb 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -19,6 +19,7 @@ package dml import ( "context" + "github.com/arana-db/arana/pkg/runtime/cmp" "strings" ) @@ -501,17 +502,16 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt }, }, } + table := tableSource.Source.(ast.TableName) + actualTb := table.Suffix() + aliasTb := tableSource.Alias + tb0 := actualTb + if shards != nil { + vt := o.Rule.MustVTable(tb0) + _, tb0, _ = vt.Topology().Smallest() + } if _, ok = stmt.Select[0].(*ast.SelectElementAll); !ok && len(stmt.Select) > 1 { - table := tableSource.Source.(ast.TableName) - actualTb := table.Suffix() - aliasTb := tableSource.Alias - - tb0 := actualTb - if shards != nil { - vt := o.Rule.MustVTable(tb0) - _, tb0, _ = vt.Topology().Smallest() - } metadata, err := loadMetadataByTable(ctx, tb0) if err != nil { return nil, err @@ -533,6 +533,14 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt selectStmt.Select = selectElements } + if stmt.Where != nil { + selectStmt.Where = stmt.Where.Clone() + err := filterWhereByTable(ctx, selectStmt.Where, tb0, aliasTb) + if err != nil { + return nil, err + } + } + optimizer := &optimize.Optimizer{ Rule: o.Rule, Stmt: selectStmt, @@ -811,3 +819,149 @@ func loadMetadataByTable(ctx context.Context, tb string) (*proto.TableMetadata, } return metadata, nil } + +func filterWhereByTable(ctx context.Context, where ast.ExpressionNode, table string, alis string) error { + metadata, err := loadMetadataByTable(ctx, table) + if err != nil { + return errors.WithStack(err) + } + + if err = filterNodeByTable(where, metadata, alis); err != nil { + return errors.WithStack(err) + } + + return nil +} + +var replaceNode = &ast.BinaryComparisonPredicateNode{ + Left: &ast.AtomPredicateNode{A: &ast.ConstantExpressionAtom{Inner: 1}}, + Right: &ast.AtomPredicateNode{A: &ast.ConstantExpressionAtom{Inner: 1}}, + Op: cmp.Ceq, +} + +func filterNodeByTable(expNode ast.ExpressionNode, metadata *proto.TableMetadata, alis string) error { + predicateNode, ok := expNode.(*ast.PredicateExpressionNode) + if ok { + bcpn, bcOk := predicateNode.P.(*ast.BinaryComparisonPredicateNode) + if bcOk { + columnNode, ok := bcpn.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if !ok { + return errors.New("invalid node") + } + if columnNode.Prefix() != "" { + if columnNode.Prefix() != metadata.Name && columnNode.Prefix() != alis { + predicateNode.P = replaceNode + } + } else { + _, ok := metadata.Columns[columnNode.Suffix()] + if !ok { + 
predicateNode.P = replaceNode + } + } + rightColumn, ok := bcpn.Right.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if ok { + if rightColumn.Prefix() != "" { + if rightColumn.Prefix() != metadata.Name && rightColumn.Prefix() != alis { + return errors.New("not support node") + } + } else { + _, ok := metadata.Columns[rightColumn.Suffix()] + if !ok { + return errors.New("not support node") + } + } + } + return nil + } + + lpn, likeOk := predicateNode.P.(*ast.LikePredicateNode) + if likeOk { + columnNode := lpn.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if columnNode.Prefix() != "" { + if columnNode.Prefix() != metadata.Name && columnNode.Prefix() != alis { + predicateNode.P = replaceNode + } + } else { + _, ok := metadata.Columns[columnNode.Suffix()] + if !ok { + predicateNode.P = replaceNode + } + } + return nil + } + + ipn, inOk := predicateNode.P.(*ast.InPredicateNode) + if inOk { + columnNode, ok := ipn.P.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if !ok { + return errors.New("invalid node") + } + if columnNode.Prefix() != "" { + if columnNode.Prefix() != metadata.Name && columnNode.Prefix() != alis { + predicateNode.P = replaceNode + } + } else { + _, ok := metadata.Columns[columnNode.Suffix()] + if !ok { + predicateNode.P = replaceNode + } + } + return nil + } + + bpn, betweenOk := predicateNode.P.(*ast.BetweenPredicateNode) + if betweenOk { + columnNode, ok := bpn.Key.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if !ok { + return errors.New("invalid node") + } + if columnNode.Prefix() != "" { + if columnNode.Prefix() != metadata.Name && columnNode.Prefix() != alis { + predicateNode.P = replaceNode + } + } else { + _, ok := metadata.Columns[columnNode.Suffix()] + if !ok { + predicateNode.P = replaceNode + } + } + + //columnNode := bpn.Right.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + return nil + } + + rpn, regexpOk := predicateNode.P.(*ast.RegexpPredicationNode) + if regexpOk { + columnNode, ok := rpn.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if !ok { + return errors.New("invalid node") + } + if columnNode.Prefix() != "" { + if columnNode.Prefix() != metadata.Name && columnNode.Prefix() != alis { + predicateNode.P = replaceNode + } + } else { + _, ok := metadata.Columns[columnNode.Suffix()] + if !ok { + predicateNode.P = replaceNode + } + } + return nil + } + + return errors.New("invalid node") + } + + node, ok := expNode.(*ast.LogicalExpressionNode) + if !ok { + return errors.New("invalid node") + } + + if err := filterNodeByTable(node.Left, metadata, alis); err != nil { + return err + } + if err := filterNodeByTable(node.Right, metadata, alis); err != nil { + return err + } + return nil +} From 4b5fde39553bc469650e554419bb419b9d34235f Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Fri, 9 Jun 2023 16:38:00 +0800 Subject: [PATCH 17/19] imports-formatter --- pkg/mysql/server.go | 2 +- pkg/runtime/optimize/dml/select.go | 2 +- pkg/runtime/optimize/optimizer_test.go | 8 ++++---- pkg/runtime/plan/dml/hash_join.go | 1 + pkg/runtime/plan/dml/hash_join_test.go | 1 + 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pkg/mysql/server.go b/pkg/mysql/server.go index c2aa1cd3..15012985 100644 --- a/pkg/mysql/server.go +++ b/pkg/mysql/server.go @@ -34,7 +34,6 @@ import ( import ( _ "github.com/arana-db/parser/test_driver" - uconfig "github.com/arana-db/arana/pkg/util/config" perrors "github.com/pkg/errors" "go.uber.org/atomic" @@ -46,6 +45,7 @@ import ( 
"github.com/arana-db/arana/pkg/mysql/errors" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/security" + uconfig "github.com/arana-db/arana/pkg/util/config" "github.com/arana-db/arana/pkg/util/log" ) diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index 5723bfcb..b9292237 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -19,7 +19,6 @@ package dml import ( "context" - "github.com/arana-db/arana/pkg/runtime/cmp" "strings" ) @@ -34,6 +33,7 @@ import ( "github.com/arana-db/arana/pkg/proto/hint" "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/cmp" rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/pkg/runtime/misc/extvalue" "github.com/arana-db/arana/pkg/runtime/optimize" diff --git a/pkg/runtime/optimize/optimizer_test.go b/pkg/runtime/optimize/optimizer_test.go index 0049b5a1..87b93b5e 100644 --- a/pkg/runtime/optimize/optimizer_test.go +++ b/pkg/runtime/optimize/optimizer_test.go @@ -20,10 +20,6 @@ package optimize_test import ( "context" "fmt" - consts "github.com/arana-db/arana/pkg/constants/mysql" - "github.com/arana-db/arana/pkg/dataset" - "github.com/arana-db/arana/pkg/mysql" - "github.com/arana-db/arana/pkg/mysql/rows" "strings" "testing" ) @@ -37,6 +33,10 @@ import ( ) import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/resultx" diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go index 7decd5e7..80deedab 100644 --- a/pkg/runtime/plan/dml/hash_join.go +++ b/pkg/runtime/plan/dml/hash_join.go @@ -25,6 +25,7 @@ import ( import ( "github.com/cespare/xxhash/v2" + "github.com/pkg/errors" ) diff --git a/pkg/runtime/plan/dml/hash_join_test.go b/pkg/runtime/plan/dml/hash_join_test.go index 2b8adcb2..b0a30147 100644 --- a/pkg/runtime/plan/dml/hash_join_test.go +++ b/pkg/runtime/plan/dml/hash_join_test.go @@ -26,6 +26,7 @@ import ( import ( "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" ) From 56dfc3859b0c8dd1b455685b6eb871b412f1713e Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Sat, 1 Jul 2023 12:12:33 +0800 Subject: [PATCH 18/19] fix ut --- pkg/runtime/optimize/optimizer_test.go | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/pkg/runtime/optimize/optimizer_test.go b/pkg/runtime/optimize/optimizer_test.go index 87b93b5e..411e91e9 100644 --- a/pkg/runtime/optimize/optimizer_test.go +++ b/pkg/runtime/optimize/optimizer_test.go @@ -36,7 +36,6 @@ import ( consts "github.com/arana-db/arana/pkg/constants/mysql" "github.com/arana-db/arana/pkg/dataset" "github.com/arana-db/arana/pkg/mysql" - "github.com/arana-db/arana/pkg/mysql/rows" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/resultx" @@ -107,28 +106,15 @@ func TestOptimizer_OptimizeHashJoin(t *testing.T) { fakeData := &dataset.VirtualDataset{} if buildPlan { fakeData.Columns = append(studentFields, mysql.NewField("uid", consts.FieldTypeLongLong)) - for i := int64(0); i < 8; i++ { - fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ - proto.NewValueInt64(i), - 
proto.NewValueInt64(i), - })) - } result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() buildPlan = false } else { fakeData.Columns = append(salariesFields, mysql.NewField("uid", consts.FieldTypeLongLong)) - for i := int64(10); i > 3; i-- { - fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ - proto.NewValueInt64(i), - proto.NewValueInt64(i), - })) - } result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() } return result, nil - }). - AnyTimes() + }).MinTimes(2) fakeData := make(map[string]*proto.TableMetadata) // fake data From 3ab22ad28b54e38cab2cf2697b614ecd1333a745 Mon Sep 17 00:00:00 2001 From: huangwenkang <642380437@qq> Date: Wed, 5 Jul 2023 21:17:36 +0800 Subject: [PATCH 19/19] fix ut --- pkg/runtime/ast/expression.go | 2 +- pkg/runtime/optimize/shard_visitor_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/runtime/ast/expression.go b/pkg/runtime/ast/expression.go index d9dce765..cede33d3 100644 --- a/pkg/runtime/ast/expression.go +++ b/pkg/runtime/ast/expression.go @@ -81,7 +81,7 @@ func (l *LogicalExpressionNode) Mode() ExpressionMode { func (l *LogicalExpressionNode) Clone() ExpressionNode { return &LogicalExpressionNode{ - Op: l.Op, + Or: l.Or, Left: l.Left.Clone(), Right: l.Right.Clone(), } diff --git a/pkg/runtime/optimize/shard_visitor_test.go b/pkg/runtime/optimize/shard_visitor_test.go index f9a077ab..0a14c60f 100644 --- a/pkg/runtime/optimize/shard_visitor_test.go +++ b/pkg/runtime/optimize/shard_visitor_test.go @@ -148,7 +148,7 @@ func makeFakeRule(c *gomock.Controller, table string, mod int, ru *rule.Rule) *r Table: sm, }) - ru.SetVTable("student", &tab) - return &ru + ru.SetVTable(table, &tab) + return ru }
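Taken together, the series wires a classic build/probe hash join into arana's plan tree. As a self-contained reference, the underlying technique looks roughly like the sketch below; the row struct, the int64 keys and the plain Go map are illustrative assumptions, while the real HashJoinPlan streams proto.Dataset rows, tracks fields, and layers LEFT/RIGHT join handling on top of this.

package main

import "fmt"

// row is a toy stand-in for a result row keyed by the join column.
type row struct {
	key  int64
	vals []string
}

// hashJoin indexes the build side in a hash table, then probes it with every
// row of the probe side and emits the matched, concatenated rows.
func hashJoin(build, probe []row) [][]string {
	table := make(map[int64][]row, len(build))
	for _, b := range build {
		table[b.key] = append(table[b.key], b) // build stage
	}

	var out [][]string
	for _, p := range probe { // probe stage
		for _, b := range table[p.key] {
			out = append(out, append(append([]string{}, b.vals...), p.vals...))
		}
	}
	return out
}

func main() {
	students := []row{{key: 1, vals: []string{"alice"}}, {key: 2, vals: []string{"bob"}}}
	salaries := []row{{key: 1, vals: []string{"1000"}}, {key: 3, vals: []string{"3000"}}}
	fmt.Println(hashJoin(students, salaries)) // [[alice 1000]]
}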