diff --git a/executor/builder.go b/executor/builder.go index 2104b00d58bb4..fee88b522a325 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -788,9 +788,16 @@ func (b *executorBuilder) buildMergeJoin(v *plannercore.PhysicalMergeJoin) Execu e := &MergeJoinExec{ stmtCtx: b.ctx.GetSessionVars().StmtCtx, baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), leftExec, rightExec), - joiner: newJoiner(b.ctx, v.JoinType, v.JoinType == plannercore.RightOuterJoin, - defaultValues, v.OtherConditions, - leftExec.retTypes(), rightExec.retTypes()), + compareFuncs: v.CompareFuncs, + joiner: newJoiner( + b.ctx, + v.JoinType, + v.JoinType == plannercore.RightOuterJoin, + defaultValues, + v.OtherConditions, + leftExec.retTypes(), + rightExec.retTypes(), + ), } leftKeys := v.LeftKeys diff --git a/executor/merge_join.go b/executor/merge_join.go index a02f2684c1d26..8e3787ba3b427 100644 --- a/executor/merge_join.go +++ b/executor/merge_join.go @@ -35,7 +35,7 @@ type MergeJoinExec struct { baseExecutor stmtCtx *stmtctx.StatementContext - compareFuncs []chunk.CompareFunc + compareFuncs []expression.CompareFunc joiner joiner prepared bool @@ -75,7 +75,7 @@ type mergeJoinInnerTable struct { // for chunk executions sameKeyRows []chunk.Row - compareFuncs []chunk.CompareFunc + keyCmpFuncs []chunk.CompareFunc firstRow4Key chunk.Row curRow chunk.Row curResult *chunk.Chunk @@ -99,9 +99,9 @@ func (t *mergeJoinInnerTable) init(ctx context.Context, chk4Reader *chunk.Chunk) t.resultQueue = append(t.resultQueue, chk4Reader) t.memTracker.Consume(chk4Reader.MemoryUsage()) t.firstRow4Key, err = t.nextRow() - t.compareFuncs = make([]chunk.CompareFunc, 0, len(t.joinKeys)) + t.keyCmpFuncs = make([]chunk.CompareFunc, 0, len(t.joinKeys)) for i := range t.joinKeys { - t.compareFuncs = append(t.compareFuncs, chunk.GetCompareFunc(t.joinKeys[i].RetType)) + t.keyCmpFuncs = append(t.keyCmpFuncs, chunk.GetCompareFunc(t.joinKeys[i].RetType)) } return errors.Trace(err) } @@ -123,7 +123,7 @@ func (t *mergeJoinInnerTable) rowsWithSameKey() ([]chunk.Row, error) { t.firstRow4Key = t.curIter.End() return t.sameKeyRows, errors.Trace(err) } - compareResult := compareChunkRow(t.compareFuncs, selectedRow, t.firstRow4Key, t.joinKeys, t.joinKeys) + compareResult := compareChunkRow(t.keyCmpFuncs, selectedRow, t.firstRow4Key, t.joinKeys, t.joinKeys) if compareResult == 0 { t.sameKeyRows = append(t.sameKeyRows, selectedRow) } else { @@ -256,7 +256,6 @@ func (e *MergeJoinExec) prepare(ctx context.Context, chk *chunk.Chunk) error { return errors.Trace(err) } - e.compareFuncs = e.innerTable.compareFuncs e.prepared = true return nil } @@ -294,7 +293,10 @@ func (e *MergeJoinExec) joinToChunk(ctx context.Context, chk *chunk.Chunk) (hasM cmpResult := -1 if e.outerTable.selected[e.outerTable.row.Idx()] && len(e.innerRows) > 0 { - cmpResult = compareChunkRow(e.compareFuncs, e.outerTable.row, e.innerRows[0], e.outerTable.keys, e.innerTable.joinKeys) + cmpResult, err = e.compare(e.outerTable.row, e.innerIter4Row.Current()) + if err != nil { + return false, err + } } if cmpResult > 0 { @@ -340,6 +342,22 @@ func (e *MergeJoinExec) joinToChunk(ctx context.Context, chk *chunk.Chunk) (hasM } } +func (e *MergeJoinExec) compare(outerRow, innerRow chunk.Row) (int, error) { + outerJoinKeys := e.outerTable.keys + innerJoinKeys := e.innerTable.joinKeys + for i := range outerJoinKeys { + cmp, _, err := e.compareFuncs[i](e.ctx, outerJoinKeys[i], innerJoinKeys[i], outerRow, innerRow) + if err != nil { + return 0, err + } + + if cmp != 0 { + return int(cmp), nil + } + } + return 0, nil +} + // fetchNextInnerRows fetches the next join group, within which all the rows // have the same join key, from the inner table. func (e *MergeJoinExec) fetchNextInnerRows() (err error) { diff --git a/executor/merge_join_test.go b/executor/merge_join_test.go index 30a7fdc93f382..913b9460d4856 100644 --- a/executor/merge_join_test.go +++ b/executor/merge_join_test.go @@ -396,3 +396,36 @@ func (s *testSuite) Test3WaysMergeJoin(c *C) { result = checkPlanAndRun(tk, c, plan3, "select /*+ TIDB_SMJ(t1,t2,t3) */ * from t1 right outer join t2 on t1.c1 = t2.c1 join t3 on t1.c1 = t3.c1 order by 1") result.Check(testkit.Rows("2 2 2 3 2 4", "3 3 3 4 3 10")) } + +func (s *testSuite) TestMergeJoinDifferentTypes(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`drop table if exists t1;`) + tk.MustExec(`drop table if exists t2;`) + tk.MustExec(`create table t1(a bigint, b bit(1), index idx_a(a));`) + tk.MustExec(`create table t2(a bit(1) not null, b bit(1), index idx_a(a));`) + tk.MustExec(`insert into t1 values(1, 1);`) + tk.MustExec(`insert into t2 values(1, 1);`) + tk.MustQuery(`select hex(t1.a), hex(t2.a) from t1 inner join t2 on t1.a=t2.a;`).Check(testkit.Rows(`1 1`)) + + tk.MustExec(`drop table if exists t1;`) + tk.MustExec(`drop table if exists t2;`) + tk.MustExec(`create table t1(a float, b double, index idx_a(a));`) + tk.MustExec(`create table t2(a double not null, b double, index idx_a(a));`) + tk.MustExec(`insert into t1 values(1, 1);`) + tk.MustExec(`insert into t2 values(1, 1);`) + tk.MustQuery(`select t1.a, t2.a from t1 inner join t2 on t1.a=t2.a;`).Check(testkit.Rows(`1 1`)) + + tk.MustExec(`drop table if exists t1;`) + tk.MustExec(`drop table if exists t2;`) + tk.MustExec(`create table t1(a bigint signed, b bigint, index idx_a(a));`) + tk.MustExec(`create table t2(a bigint unsigned, b bigint, index idx_a(a));`) + tk.MustExec(`insert into t1 values(-1, 0), (-1, 0), (0, 0), (0, 0), (pow(2, 63), 0), (pow(2, 63), 0);`) + tk.MustExec(`insert into t2 values(18446744073709551615, 0), (18446744073709551615, 0), (0, 0), (0, 0), (pow(2, 63), 0), (pow(2, 63), 0);`) + tk.MustQuery(`select t1.a, t2.a from t1 join t2 on t1.a=t2.a order by t1.a;`).Check(testkit.Rows( + `0 0`, + `0 0`, + `0 0`, + `0 0`, + )) +} diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 5832af7cd7703..f5cafd3729cd7 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -1383,7 +1383,7 @@ func (b *builtinLTIntSig) Clone() builtinFunc { } func (b *builtinLTIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLT(compareInt(b.ctx, b.args, row)) + return resOfLT(CompareInt(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLTRealSig struct { @@ -1397,7 +1397,7 @@ func (b *builtinLTRealSig) Clone() builtinFunc { } func (b *builtinLTRealSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLT(compareReal(b.ctx, b.args, row)) + return resOfLT(CompareReal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLTDecimalSig struct { @@ -1411,7 +1411,7 @@ func (b *builtinLTDecimalSig) Clone() builtinFunc { } func (b *builtinLTDecimalSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLT(compareDecimal(b.ctx, b.args, row)) + return resOfLT(CompareDecimal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLTStringSig struct { @@ -1425,7 +1425,7 @@ func (b *builtinLTStringSig) Clone() builtinFunc { } func (b *builtinLTStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLT(compareString(b.args, row, b.ctx)) + return resOfLT(CompareString(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLTDurationSig struct { @@ -1439,7 +1439,7 @@ func (b *builtinLTDurationSig) Clone() builtinFunc { } func (b *builtinLTDurationSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLT(compareDuration(b.args, row, b.ctx)) + return resOfLT(CompareDuration(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLTTimeSig struct { @@ -1453,7 +1453,7 @@ func (b *builtinLTTimeSig) Clone() builtinFunc { } func (b *builtinLTTimeSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLT(compareTime(b.ctx, b.args, row)) + return resOfLT(CompareTime(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLTJSONSig struct { @@ -1467,7 +1467,7 @@ func (b *builtinLTJSONSig) Clone() builtinFunc { } func (b *builtinLTJSONSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLT(compareJSON(b.ctx, b.args, row)) + return resOfLT(CompareJSON(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLEIntSig struct { @@ -1481,7 +1481,7 @@ func (b *builtinLEIntSig) Clone() builtinFunc { } func (b *builtinLEIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLE(compareInt(b.ctx, b.args, row)) + return resOfLE(CompareInt(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLERealSig struct { @@ -1495,7 +1495,7 @@ func (b *builtinLERealSig) Clone() builtinFunc { } func (b *builtinLERealSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLE(compareReal(b.ctx, b.args, row)) + return resOfLE(CompareReal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLEDecimalSig struct { @@ -1509,7 +1509,7 @@ func (b *builtinLEDecimalSig) Clone() builtinFunc { } func (b *builtinLEDecimalSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLE(compareDecimal(b.ctx, b.args, row)) + return resOfLE(CompareDecimal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLEStringSig struct { @@ -1523,7 +1523,7 @@ func (b *builtinLEStringSig) Clone() builtinFunc { } func (b *builtinLEStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLE(compareString(b.args, row, b.ctx)) + return resOfLE(CompareString(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLEDurationSig struct { @@ -1537,7 +1537,7 @@ func (b *builtinLEDurationSig) Clone() builtinFunc { } func (b *builtinLEDurationSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLE(compareDuration(b.args, row, b.ctx)) + return resOfLE(CompareDuration(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLETimeSig struct { @@ -1551,7 +1551,7 @@ func (b *builtinLETimeSig) Clone() builtinFunc { } func (b *builtinLETimeSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLE(compareTime(b.ctx, b.args, row)) + return resOfLE(CompareTime(b.ctx, b.args[0], b.args[1], row, row)) } type builtinLEJSONSig struct { @@ -1565,7 +1565,7 @@ func (b *builtinLEJSONSig) Clone() builtinFunc { } func (b *builtinLEJSONSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfLE(compareJSON(b.ctx, b.args, row)) + return resOfLE(CompareJSON(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGTIntSig struct { @@ -1579,7 +1579,7 @@ func (b *builtinGTIntSig) Clone() builtinFunc { } func (b *builtinGTIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGT(compareInt(b.ctx, b.args, row)) + return resOfGT(CompareInt(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGTRealSig struct { @@ -1593,7 +1593,7 @@ func (b *builtinGTRealSig) Clone() builtinFunc { } func (b *builtinGTRealSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGT(compareReal(b.ctx, b.args, row)) + return resOfGT(CompareReal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGTDecimalSig struct { @@ -1607,7 +1607,7 @@ func (b *builtinGTDecimalSig) Clone() builtinFunc { } func (b *builtinGTDecimalSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGT(compareDecimal(b.ctx, b.args, row)) + return resOfGT(CompareDecimal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGTStringSig struct { @@ -1621,7 +1621,7 @@ func (b *builtinGTStringSig) Clone() builtinFunc { } func (b *builtinGTStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGT(compareString(b.args, row, b.ctx)) + return resOfGT(CompareString(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGTDurationSig struct { @@ -1635,7 +1635,7 @@ func (b *builtinGTDurationSig) Clone() builtinFunc { } func (b *builtinGTDurationSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGT(compareDuration(b.args, row, b.ctx)) + return resOfGT(CompareDuration(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGTTimeSig struct { @@ -1649,7 +1649,7 @@ func (b *builtinGTTimeSig) Clone() builtinFunc { } func (b *builtinGTTimeSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGT(compareTime(b.ctx, b.args, row)) + return resOfGT(CompareTime(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGTJSONSig struct { @@ -1663,7 +1663,7 @@ func (b *builtinGTJSONSig) Clone() builtinFunc { } func (b *builtinGTJSONSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGT(compareJSON(b.ctx, b.args, row)) + return resOfGT(CompareJSON(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGEIntSig struct { @@ -1677,7 +1677,7 @@ func (b *builtinGEIntSig) Clone() builtinFunc { } func (b *builtinGEIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGE(compareInt(b.ctx, b.args, row)) + return resOfGE(CompareInt(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGERealSig struct { @@ -1691,7 +1691,7 @@ func (b *builtinGERealSig) Clone() builtinFunc { } func (b *builtinGERealSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGE(compareReal(b.ctx, b.args, row)) + return resOfGE(CompareReal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGEDecimalSig struct { @@ -1705,7 +1705,7 @@ func (b *builtinGEDecimalSig) Clone() builtinFunc { } func (b *builtinGEDecimalSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGE(compareDecimal(b.ctx, b.args, row)) + return resOfGE(CompareDecimal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGEStringSig struct { @@ -1719,7 +1719,7 @@ func (b *builtinGEStringSig) Clone() builtinFunc { } func (b *builtinGEStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGE(compareString(b.args, row, b.ctx)) + return resOfGE(CompareString(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGEDurationSig struct { @@ -1733,7 +1733,7 @@ func (b *builtinGEDurationSig) Clone() builtinFunc { } func (b *builtinGEDurationSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGE(compareDuration(b.args, row, b.ctx)) + return resOfGE(CompareDuration(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGETimeSig struct { @@ -1747,7 +1747,7 @@ func (b *builtinGETimeSig) Clone() builtinFunc { } func (b *builtinGETimeSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGE(compareTime(b.ctx, b.args, row)) + return resOfGE(CompareTime(b.ctx, b.args[0], b.args[1], row, row)) } type builtinGEJSONSig struct { @@ -1761,7 +1761,7 @@ func (b *builtinGEJSONSig) Clone() builtinFunc { } func (b *builtinGEJSONSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfGE(compareJSON(b.ctx, b.args, row)) + return resOfGE(CompareJSON(b.ctx, b.args[0], b.args[1], row, row)) } type builtinEQIntSig struct { @@ -1775,7 +1775,7 @@ func (b *builtinEQIntSig) Clone() builtinFunc { } func (b *builtinEQIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfEQ(compareInt(b.ctx, b.args, row)) + return resOfEQ(CompareInt(b.ctx, b.args[0], b.args[1], row, row)) } type builtinEQRealSig struct { @@ -1789,7 +1789,7 @@ func (b *builtinEQRealSig) Clone() builtinFunc { } func (b *builtinEQRealSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfEQ(compareReal(b.ctx, b.args, row)) + return resOfEQ(CompareReal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinEQDecimalSig struct { @@ -1803,7 +1803,7 @@ func (b *builtinEQDecimalSig) Clone() builtinFunc { } func (b *builtinEQDecimalSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfEQ(compareDecimal(b.ctx, b.args, row)) + return resOfEQ(CompareDecimal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinEQStringSig struct { @@ -1817,7 +1817,7 @@ func (b *builtinEQStringSig) Clone() builtinFunc { } func (b *builtinEQStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfEQ(compareString(b.args, row, b.ctx)) + return resOfEQ(CompareString(b.ctx, b.args[0], b.args[1], row, row)) } type builtinEQDurationSig struct { @@ -1831,7 +1831,7 @@ func (b *builtinEQDurationSig) Clone() builtinFunc { } func (b *builtinEQDurationSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfEQ(compareDuration(b.args, row, b.ctx)) + return resOfEQ(CompareDuration(b.ctx, b.args[0], b.args[1], row, row)) } type builtinEQTimeSig struct { @@ -1845,7 +1845,7 @@ func (b *builtinEQTimeSig) Clone() builtinFunc { } func (b *builtinEQTimeSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfEQ(compareTime(b.ctx, b.args, row)) + return resOfEQ(CompareTime(b.ctx, b.args[0], b.args[1], row, row)) } type builtinEQJSONSig struct { @@ -1859,7 +1859,7 @@ func (b *builtinEQJSONSig) Clone() builtinFunc { } func (b *builtinEQJSONSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfEQ(compareJSON(b.ctx, b.args, row)) + return resOfEQ(CompareJSON(b.ctx, b.args[0], b.args[1], row, row)) } type builtinNEIntSig struct { @@ -1873,7 +1873,7 @@ func (b *builtinNEIntSig) Clone() builtinFunc { } func (b *builtinNEIntSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfNE(compareInt(b.ctx, b.args, row)) + return resOfNE(CompareInt(b.ctx, b.args[0], b.args[1], row, row)) } type builtinNERealSig struct { @@ -1887,7 +1887,7 @@ func (b *builtinNERealSig) Clone() builtinFunc { } func (b *builtinNERealSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfNE(compareReal(b.ctx, b.args, row)) + return resOfNE(CompareReal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinNEDecimalSig struct { @@ -1901,7 +1901,7 @@ func (b *builtinNEDecimalSig) Clone() builtinFunc { } func (b *builtinNEDecimalSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfNE(compareDecimal(b.ctx, b.args, row)) + return resOfNE(CompareDecimal(b.ctx, b.args[0], b.args[1], row, row)) } type builtinNEStringSig struct { @@ -1915,7 +1915,7 @@ func (b *builtinNEStringSig) Clone() builtinFunc { } func (b *builtinNEStringSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfNE(compareString(b.args, row, b.ctx)) + return resOfNE(CompareString(b.ctx, b.args[0], b.args[1], row, row)) } type builtinNEDurationSig struct { @@ -1929,7 +1929,7 @@ func (b *builtinNEDurationSig) Clone() builtinFunc { } func (b *builtinNEDurationSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfNE(compareDuration(b.args, row, b.ctx)) + return resOfNE(CompareDuration(b.ctx, b.args[0], b.args[1], row, row)) } type builtinNETimeSig struct { @@ -1943,7 +1943,7 @@ func (b *builtinNETimeSig) Clone() builtinFunc { } func (b *builtinNETimeSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfNE(compareTime(b.ctx, b.args, row)) + return resOfNE(CompareTime(b.ctx, b.args[0], b.args[1], row, row)) } type builtinNEJSONSig struct { @@ -1957,7 +1957,7 @@ func (b *builtinNEJSONSig) Clone() builtinFunc { } func (b *builtinNEJSONSig) evalInt(row chunk.Row) (val int64, isNull bool, err error) { - return resOfNE(compareJSON(b.ctx, b.args, row)) + return resOfNE(CompareJSON(b.ctx, b.args[0], b.args[1], row, row)) } type builtinNullEQIntSig struct { @@ -2269,16 +2269,41 @@ func resOfNE(val int64, isNull bool, err error) (int64, bool, error) { return val, false, nil } -func compareInt(ctx sessionctx.Context, args []Expression, row chunk.Row) (val int64, isNull bool, err error) { - arg0, isNull0, err := args[0].EvalInt(ctx, row) - if isNull0 || err != nil { - return 0, isNull0, errors.Trace(err) +// compareNull compares null values based on the following rules. +// 1. NULL is considered to be equal to NULL +// 2. NULL is considered to be smaller than a non-NULL value. +// NOTE: (lhsIsNull == true) or (rhsIsNull == true) is required. +func compareNull(lhsIsNull, rhsIsNull bool) int64 { + if lhsIsNull && rhsIsNull { + return 0 } - arg1, isNull1, err := args[1].EvalInt(ctx, row) - if isNull1 || err != nil { - return 0, isNull1, errors.Trace(err) + if lhsIsNull { + return -1 + } + return 1 +} + +// CompareFunc defines the compare function prototype. +type CompareFunc = func(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) + +// CompareInt compares two integers. +func CompareInt(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) { + arg0, isNull0, err := lhsArg.EvalInt(sctx, lhsRow) + if err != nil { + return 0, true, err } - isUnsigned0, isUnsigned1 := mysql.HasUnsignedFlag(args[0].GetType().Flag), mysql.HasUnsignedFlag(args[1].GetType().Flag) + + arg1, isNull1, err := rhsArg.EvalInt(sctx, rhsRow) + if err != nil { + return 0, true, err + } + + // compare null values. + if isNull0 || isNull1 { + return compareNull(isNull0, isNull1), true, nil + } + + isUnsigned0, isUnsigned1 := mysql.HasUnsignedFlag(lhsArg.GetType().Flag), mysql.HasUnsignedFlag(rhsArg.GetType().Flag) var res int switch { case isUnsigned0 && isUnsigned1: @@ -2301,77 +2326,110 @@ func compareInt(ctx sessionctx.Context, args []Expression, row chunk.Row) (val i return int64(res), false, nil } -func compareString(args []Expression, row chunk.Row, ctx sessionctx.Context) (val int64, isNull bool, err error) { - arg0, isNull0, err := args[0].EvalString(ctx, row) - if isNull0 || err != nil { - return 0, isNull0, errors.Trace(err) +// CompareString compares two strings. +func CompareString(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) { + arg0, isNull0, err := lhsArg.EvalString(sctx, lhsRow) + if err != nil { + return 0, true, err } - arg1, isNull1, err := args[1].EvalString(ctx, row) - if isNull1 || err != nil { - return 0, isNull1, errors.Trace(err) + + arg1, isNull1, err := rhsArg.EvalString(sctx, rhsRow) + if err != nil { + return 0, true, err + } + + if isNull0 || isNull1 { + return compareNull(isNull0, isNull1), true, nil } return int64(types.CompareString(arg0, arg1)), false, nil } -func compareReal(ctx sessionctx.Context, args []Expression, row chunk.Row) (val int64, isNull bool, err error) { - arg0, isNull0, err := args[0].EvalReal(ctx, row) - if isNull0 || err != nil { - return 0, isNull0, errors.Trace(err) +// CompareReal compares two float-point values. +func CompareReal(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) { + arg0, isNull0, err := lhsArg.EvalReal(sctx, lhsRow) + if err != nil { + return 0, true, err } - arg1, isNull1, err := args[1].EvalReal(ctx, row) - if isNull1 || err != nil { - return 0, isNull1, errors.Trace(err) + + arg1, isNull1, err := rhsArg.EvalReal(sctx, rhsRow) + if err != nil { + return 0, true, err + } + + if isNull0 || isNull1 { + return compareNull(isNull0, isNull1), true, nil } return int64(types.CompareFloat64(arg0, arg1)), false, nil } -func compareDecimal(ctx sessionctx.Context, args []Expression, row chunk.Row) (val int64, isNull bool, err error) { - arg0, isNull0, err := args[0].EvalDecimal(ctx, row) - if isNull0 || err != nil { - return 0, isNull0, errors.Trace(err) +// CompareDecimal compares two decimals. +func CompareDecimal(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) { + arg0, isNull0, err := lhsArg.EvalDecimal(sctx, lhsRow) + if err != nil { + return 0, true, err } - arg1, isNull1, err := args[1].EvalDecimal(ctx, row) + + arg1, isNull1, err := rhsArg.EvalDecimal(sctx, rhsRow) if err != nil { return 0, true, errors.Trace(err) } - if isNull1 || err != nil { - return 0, isNull1, errors.Trace(err) + + if isNull0 || isNull1 { + return compareNull(isNull0, isNull1), true, nil } return int64(arg0.Compare(arg1)), false, nil } -func compareTime(ctx sessionctx.Context, args []Expression, row chunk.Row) (int64, bool, error) { - arg0, isNull0, err := args[0].EvalTime(ctx, row) - if isNull0 || err != nil { - return 0, isNull0, errors.Trace(err) +// CompareTime compares two datetime or timestamps. +func CompareTime(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) { + arg0, isNull0, err := lhsArg.EvalTime(sctx, lhsRow) + if err != nil { + return 0, true, err } - arg1, isNull1, err := args[1].EvalTime(ctx, row) - if isNull1 || err != nil { - return 0, isNull1, errors.Trace(err) + + arg1, isNull1, err := rhsArg.EvalTime(sctx, rhsRow) + if err != nil { + return 0, true, err + } + + if isNull0 || isNull1 { + return compareNull(isNull0, isNull1), true, nil } return int64(arg0.Compare(arg1)), false, nil } -func compareDuration(args []Expression, row chunk.Row, ctx sessionctx.Context) (int64, bool, error) { - arg0, isNull0, err := args[0].EvalDuration(ctx, row) - if isNull0 || err != nil { - return 0, isNull0, errors.Trace(err) +// CompareDuration compares two durations. +func CompareDuration(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) { + arg0, isNull0, err := lhsArg.EvalDuration(sctx, lhsRow) + if err != nil { + return 0, true, err } - arg1, isNull1, err := args[1].EvalDuration(ctx, row) - if isNull1 || err != nil { - return 0, isNull1, errors.Trace(err) + + arg1, isNull1, err := rhsArg.EvalDuration(sctx, rhsRow) + if err != nil { + return 0, true, err + } + + if isNull0 || isNull1 { + return compareNull(isNull0, isNull1), true, nil } return int64(arg0.Compare(arg1)), false, nil } -func compareJSON(ctx sessionctx.Context, args []Expression, row chunk.Row) (int64, bool, error) { - arg0, isNull0, err := args[0].EvalJSON(ctx, row) - if isNull0 || err != nil { - return 0, isNull0, errors.Trace(err) +// CompareJSON compares two JSONs. +func CompareJSON(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) { + arg0, isNull0, err := lhsArg.EvalJSON(sctx, lhsRow) + if err != nil { + return 0, true, err } - arg1, isNull1, err := args[1].EvalJSON(ctx, row) - if isNull1 || err != nil { - return 0, isNull1, errors.Trace(err) + + arg1, isNull1, err := rhsArg.EvalJSON(sctx, rhsRow) + if err != nil { + return 0, true, err + } + + if isNull0 || isNull1 { + return compareNull(isNull0, isNull1), true, nil } return int64(json.CompareBinary(arg0, arg1)), false, nil } diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 9de2863afee56..1369b13bec39e 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/ranger" + "github.com/pingcap/tidb/util/set" ) func (p *LogicalUnionScan) exhaustPhysicalPlans(prop *property.PhysicalProperty) []PhysicalPlan { @@ -65,20 +66,23 @@ func findMaxPrefixLen(candidates [][]*expression.Column, keys []*expression.Colu } func (p *LogicalJoin) moveEqualToOtherConditions(offsets []int) []expression.Expression { - otherConds := make([]expression.Expression, len(p.OtherConditions)) + // Construct used equal condition set based on the equal condition offsets. + usedEqConds := set.NewIntSet() + for _, eqCondIdx := range offsets { + usedEqConds.Insert(eqCondIdx) + } + + // Construct otherConds, which is composed of the original other conditions + // and the remained unused equal conditions. + numOtherConds := len(p.OtherConditions) + len(p.EqualConditions) - len(offsets) + otherConds := make([]expression.Expression, len(p.OtherConditions), numOtherConds) copy(otherConds, p.OtherConditions) - for i, eqCond := range p.EqualConditions { - match := false - for _, offset := range offsets { - if i == offset { - match = true - break - } - } - if !match { - otherConds = append(otherConds, eqCond) + for eqCondIdx := range p.EqualConditions { + if !usedEqConds.Exist(eqCondIdx) { + otherConds = append(otherConds, p.EqualConditions[eqCondIdx]) } } + return otherConds } @@ -135,6 +139,7 @@ func (p *LogicalJoin) getMergeJoin(prop *property.PhysicalProperty) []PhysicalPl }.init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt)) mergeJoin.SetSchema(p.schema) mergeJoin.OtherConditions = p.moveEqualToOtherConditions(offsets) + mergeJoin.initCompareFuncs() if reqProps, ok := mergeJoin.tryToGetChildReqProp(prop); ok { mergeJoin.childrenReqProps = reqProps joins = append(joins, mergeJoin) @@ -241,9 +246,32 @@ func (p *LogicalJoin) getEnforcedMergeJoin(prop *property.PhysicalProperty) []Ph }.init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt)) enforcedPhysicalMergeJoin.SetSchema(p.schema) enforcedPhysicalMergeJoin.childrenReqProps = []*property.PhysicalProperty{lProp, rProp} + enforcedPhysicalMergeJoin.initCompareFuncs() return []PhysicalPlan{enforcedPhysicalMergeJoin} } +func (p *PhysicalMergeJoin) initCompareFuncs() { + p.CompareFuncs = make([]expression.CompareFunc, 0, len(p.LeftKeys)) + for i := range p.LeftKeys { + switch expression.GetAccurateCmpType(p.LeftKeys[i], p.RightKeys[i]) { + case types.ETInt: + p.CompareFuncs = append(p.CompareFuncs, expression.CompareInt) + case types.ETReal: + p.CompareFuncs = append(p.CompareFuncs, expression.CompareReal) + case types.ETDecimal: + p.CompareFuncs = append(p.CompareFuncs, expression.CompareDecimal) + case types.ETString: + p.CompareFuncs = append(p.CompareFuncs, expression.CompareString) + case types.ETDuration: + p.CompareFuncs = append(p.CompareFuncs, expression.CompareDuration) + case types.ETDatetime, types.ETTimestamp: + p.CompareFuncs = append(p.CompareFuncs, expression.CompareTime) + case types.ETJson: + p.CompareFuncs = append(p.CompareFuncs, expression.CompareJSON) + } + } +} + func (p *LogicalJoin) getHashJoins(prop *property.PhysicalProperty) []PhysicalPlan { if !prop.IsEmpty() { // hash join doesn't promise any orders return nil diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 4c66330fabdfe..68de1e1e2481a 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -241,6 +241,7 @@ type PhysicalMergeJoin struct { JoinType JoinType + CompareFuncs []expression.CompareFunc LeftConditions []expression.Expression RightConditions []expression.Expression OtherConditions []expression.Expression diff --git a/util/set/int_set.go b/util/set/int_set.go new file mode 100644 index 0000000000000..9fef718e0c42a --- /dev/null +++ b/util/set/int_set.go @@ -0,0 +1,52 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package set + +// IntSet is a int set. +type IntSet map[int]struct{} + +// NewIntSet builds a IntSet. +func NewIntSet() IntSet { + return make(map[int]struct{}) +} + +// Exist checks whether `val` exists in `s`. +func (s IntSet) Exist(val int) bool { + _, ok := s[val] + return ok +} + +// Insert inserts `val` into `s`. +func (s IntSet) Insert(val int) { + s[val] = struct{}{} +} + +// Int64Set is a int64 set. +type Int64Set map[int64]struct{} + +// NewInt64Set builds a Int64Set. +func NewInt64Set() Int64Set { + return make(map[int64]struct{}) +} + +// Exist checks whether `val` exists in `s`. +func (s Int64Set) Exist(val int64) bool { + _, ok := s[val] + return ok +} + +// Insert inserts `val` into `s`. +func (s Int64Set) Insert(val int64) { + s[val] = struct{}{} +}