Skip to content

Commit

Permalink
executor, types: refactor CompareDatum (#29866)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjhuang2016 authored Nov 22, 2021
1 parent e482122 commit 2bf67ff
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 1 deletion.
9 changes: 8 additions & 1 deletion executor/distsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,13 @@ func (w *tableWorker) compareData(ctx context.Context, task *lookupTableTask, ta
chk := newFirstChunk(tableReader)
tblInfo := w.idxLookup.table.Meta()
vals := make([]types.Datum, 0, len(w.idxTblCols))

// Prepare collator for compare.
collators := make([]collate.Collator, 0, len(w.idxColTps))
for _, tp := range w.idxColTps {
collators = append(collators, collate.GetCollator(tp.Collate))
}

for {
err := Next(ctx, tableReader, chk)
if err != nil {
Expand Down Expand Up @@ -1169,7 +1176,7 @@ func (w *tableWorker) compareData(ctx context.Context, task *lookupTableTask, ta
tp := &col.FieldType
idxVal := idxRow.GetDatum(i, tp)
tablecodec.TruncateIndexValue(&idxVal, w.idxLookup.index.Columns[i], col.ColumnInfo)
cmpRes, err := idxVal.CompareDatum(sctx, &vals[i])
cmpRes, err := idxVal.Compare(sctx, &vals[i], collators[i])
if err != nil {
return ErrDataInConsistentMisMatchIndex.GenWithStackByArgs(col.Name,
handle, idxRow.GetDatum(i, tp), vals[i], err)
Expand Down
136 changes: 136 additions & 0 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/parser/types"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/hack"
)

Expand Down Expand Up @@ -548,7 +549,63 @@ func (d *Datum) SetValue(val interface{}, tp *types.FieldType) {
}
}

// Compare compares datum to another datum.
// Notes: don't rely on datum.collation to get the collator, it's tend to buggy.
// TODO: use this function to replace CompareDatum. After we remove all of usage of CompareDatum, we can rename this function back to CompareDatum.
func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collate.Collator) (int, error) {
if d.k == KindMysqlJSON && ad.k != KindMysqlJSON {
cmp, err := ad.Compare(sc, d, comparer)
return cmp * -1, errors.Trace(err)
}
switch ad.k {
case KindNull:
if d.k == KindNull {
return 0, nil
}
return 1, nil
case KindMinNotNull:
if d.k == KindNull {
return -1, nil
} else if d.k == KindMinNotNull {
return 0, nil
}
return 1, nil
case KindMaxValue:
if d.k == KindMaxValue {
return 0, nil
}
return -1, nil
case KindInt64:
return d.compareInt64(sc, ad.GetInt64())
case KindUint64:
return d.compareUint64(sc, ad.GetUint64())
case KindFloat32, KindFloat64:
return d.compareFloat64(sc, ad.GetFloat64())
case KindString:
return d.compareStringNew(sc, ad.GetString(), comparer)
case KindBytes:
return comparer.Compare(d.GetString(), ad.GetString()), nil
case KindMysqlDecimal:
return d.compareMysqlDecimal(sc, ad.GetMysqlDecimal())
case KindMysqlDuration:
return d.compareMysqlDuration(sc, ad.GetMysqlDuration())
case KindMysqlEnum:
return d.compareMysqlEnumNew(sc, ad.GetMysqlEnum(), comparer)
case KindBinaryLiteral, KindMysqlBit:
return d.compareBinaryLiteralNew(sc, ad.GetBinaryLiteral4Cmp(), comparer)
case KindMysqlSet:
return d.compareMysqlSetNew(sc, ad.GetMysqlSet(), comparer)
case KindMysqlJSON:
return d.compareMysqlJSON(sc, ad.GetMysqlJSON())
case KindMysqlTime:
return d.compareMysqlTime(sc, ad.GetMysqlTime())
default:
return 0, nil
}
}

// CompareDatum compares datum to another datum.
// Deprecated: will be replaced with Compare.
// TODO: return error properly.
func (d *Datum) CompareDatum(sc *stmtctx.StatementContext, ad *Datum) (int, error) {
if d.k == KindMysqlJSON && ad.k != KindMysqlJSON {
Expand Down Expand Up @@ -673,6 +730,39 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er
}
}

func (d *Datum) compareStringNew(sc *stmtctx.StatementContext, s string, comparer collate.Collator) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
case KindMaxValue:
return 1, nil
case KindString, KindBytes:
return comparer.Compare(d.GetString(), s), nil
case KindMysqlDecimal:
dec := new(MyDecimal)
err := sc.HandleTruncate(dec.FromString(hack.Slice(s)))
return d.GetMysqlDecimal().Compare(dec), errors.Trace(err)
case KindMysqlTime:
dt, err := ParseDatetime(sc, s)
return d.GetMysqlTime().Compare(dt), errors.Trace(err)
case KindMysqlDuration:
dur, err := ParseDuration(sc, s, MaxFsp)
return d.GetMysqlDuration().Compare(dur), errors.Trace(err)
case KindMysqlSet:
return comparer.Compare(d.GetMysqlSet().String(), s), nil
case KindMysqlEnum:
return comparer.Compare(d.GetMysqlEnum().String(), s), nil
case KindBinaryLiteral, KindMysqlBit:
return comparer.Compare(d.GetBinaryLiteral4Cmp().String(), s), nil
default:
fVal, err := StrToFloat(sc, s, false)
if err != nil {
return 0, errors.Trace(err)
}
return d.compareFloat64(sc, fVal)
}
}

func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, retCollation string) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
Expand Down Expand Up @@ -748,6 +838,52 @@ func (d *Datum) compareMysqlDuration(sc *stmtctx.StatementContext, dur Duration)
}
}

func (d *Datum) compareMysqlEnumNew(sc *stmtctx.StatementContext, enum Enum, comparer collate.Collator) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
case KindMaxValue:
return 1, nil
case KindString, KindBytes, KindMysqlEnum, KindMysqlSet:
return comparer.Compare(d.GetString(), enum.String()), nil
default:
return d.compareFloat64(sc, enum.ToNumber())
}
}

func (d *Datum) compareBinaryLiteralNew(sc *stmtctx.StatementContext, b BinaryLiteral, comparer collate.Collator) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
case KindMaxValue:
return 1, nil
case KindString, KindBytes:
fallthrough // in this case, d is converted to Binary and then compared with b
case KindBinaryLiteral, KindMysqlBit:
return comparer.Compare(d.GetBinaryLiteral4Cmp().ToString(), b.ToString()), nil
default:
val, err := b.ToInt(sc)
if err != nil {
return 0, errors.Trace(err)
}
result, err := d.compareFloat64(sc, float64(val))
return result, errors.Trace(err)
}
}

func (d *Datum) compareMysqlSetNew(sc *stmtctx.StatementContext, set Set, comparer collate.Collator) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
case KindMaxValue:
return 1, nil
case KindString, KindBytes, KindMysqlEnum, KindMysqlSet:
return comparer.Compare(d.GetString(), set.String()), nil
default:
return d.compareFloat64(sc, set.ToNumber())
}
}

func (d *Datum) compareMysqlEnum(sc *stmtctx.StatementContext, enum Enum) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
Expand Down

0 comments on commit 2bf67ff

Please sign in to comment.