From 147f954330b07af9941c0cf1b24cdfcb82a77445 Mon Sep 17 00:00:00 2001 From: siddontang Date: Mon, 7 Sep 2015 13:26:31 +0800 Subject: [PATCH 1/4] util:improve compare function, return error instead of panic. --- util/types/compare.go | 207 +++++++++++++++++++++++++++++++++++++ util/types/compare_test.go | 136 ++++++++++++++++++++++++ util/types/etc.go | 132 ++--------------------- util/types/etc_test.go | 75 -------------- util/types/helper.go | 52 ---------- 5 files changed, 350 insertions(+), 252 deletions(-) create mode 100644 util/types/compare.go create mode 100644 util/types/compare_test.go diff --git a/util/types/compare.go b/util/types/compare.go new file mode 100644 index 0000000000000..adc0d95703228 --- /dev/null +++ b/util/types/compare.go @@ -0,0 +1,207 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "github.com/juju/errors" + mysql "github.com/pingcap/tidb/mysqldef" +) + +// CompareInt64 returns an integer comparing the int64 x to y. +func CompareInt64(x, y int64) int { + if x < y { + return -1 + } else if x == y { + return 0 + } + + return 1 +} + +// CompareUint64 returns an integer comparing the uint64 x to y. +func CompareUint64(x, y uint64) int { + if x < y { + return -1 + } else if x == y { + return 0 + } + + return 1 +} + +// CompareFloat64 returns an integer comparing the float64 x to y. +func CompareFloat64(x, y float64) int { + if x < y { + return -1 + } else if x == y { + return 0 + } + + return 1 +} + +// CompareInteger returns an integer comparing the int64 x to the uint64 y. +func CompareInteger(x int64, y uint64) int { + if x < 0 { + return -1 + } + return CompareUint64(uint64(x), y) +} + +// CompareString returns an integer comparing the string x to y. +func CompareString(x, y string) int { + if x < y { + return -1 + } else if x == y { + return 0 + } + + return 1 +} + +// compareFloatString compares float a to float-formated string s. +// compareFloatString first parses s to a float value, if failed, returns error. +func compareFloatString(a float64, s string) (int, error) { + // MySQL will convert string to a float point value + // MySQL use a very loose conversation, e.g, 123.abc -> 123 + // we should do a trade off whether supporting this feature or using a strict mode + // now we use a strict mode + b, err := StrToFloat(s) + if err != nil { + return 0, err + } + return CompareFloat64(a, b), nil +} + +// compareStringFloat compares float-formated string s to float a. +func compareStringFloat(s string, a float64) (int, error) { + n, err := compareFloatString(a, s) + return -n, err +} + +func coerceCompare(a, b interface{}) (x interface{}, y interface{}) { + x, y = Coerce(a, b) + // change []byte to string for later compare + switch v := a.(type) { + case []byte: + x = string(v) + } + + switch v := b.(type) { + case []byte: + y = string(v) + } + + return x, y +} + +// Compare returns an integer comparing the interface a to b. +// a > b -> 1 +// a = b -> 0 +// a < b -> -1 +func Compare(a, b interface{}) (int, error) { + a, b = coerceCompare(a, b) + + if a == nil || b == nil { + // Check ni first, nil is always less than none nil value. + if a == nil && b != nil { + return -1, nil + } else if a != nil && b == nil { + return 1, nil + } else { + // here a and b are all nil + return 0, nil + } + } + + // TODO: support compare time type with other int, float, decimal types. + // TODO: support hexadecimal type + switch x := a.(type) { + case float64: + switch y := b.(type) { + case float64: + return CompareFloat64(x, y), nil + case string: + return compareFloatString(x, y) + } + case int64: + switch y := b.(type) { + case int64: + return CompareInt64(x, y), nil + case uint64: + return CompareInteger(x, y), nil + case string: + return compareFloatString(float64(x), y) + } + case uint64: + switch y := b.(type) { + case uint64: + return CompareUint64(x, y), nil + case int64: + return -CompareInteger(y, x), nil + case string: + return compareFloatString(float64(x), y) + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Cmp(y), nil + case string: + f, err := mysql.ConvertToDecimal(y) + if err != nil { + return 0, err + } + return x.Cmp(f), nil + } + case string: + switch y := b.(type) { + case string: + return CompareString(x, y), nil + case int64: + return compareStringFloat(x, float64(y)) + case uint64: + return compareStringFloat(x, float64(y)) + case float64: + return compareStringFloat(x, y) + case mysql.Decimal: + f, err := mysql.ConvertToDecimal(x) + if err != nil { + return 0, err + } + return f.Cmp(y), nil + case mysql.Time: + n, err := y.CompareString(x) + return -n, err + case mysql.Duration: + n, err := y.CompareString(x) + return -n, err + } + case mysql.Time: + switch y := b.(type) { + case mysql.Time: + return x.Compare(y), nil + case string: + return x.CompareString(y) + } + case mysql.Duration: + switch y := b.(type) { + case mysql.Duration: + return x.Compare(y), nil + case string: + return x.CompareString(y) + } + } + + return 0, errors.Errorf("invalid comapre type %T cmp %T", a, b) +} diff --git a/util/types/compare_test.go b/util/types/compare_test.go new file mode 100644 index 0000000000000..8f89d298834f0 --- /dev/null +++ b/util/types/compare_test.go @@ -0,0 +1,136 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "time" + + . "github.com/pingcap/check" + mysql "github.com/pingcap/tidb/mysqldef" +) + +var _ = Suite(&testCompareSuite{}) + +type testCompareSuite struct { +} + +func (s *testCompareSuite) TestCompare(c *C) { + cmpTbl := []struct { + lhs interface{} + rhs interface{} + ret int // 0, 1, -1 + }{ + {float64(1), float64(1), 0}, + {float64(1), "1", 0}, + {int64(1), int64(1), 0}, + {int64(-1), uint64(1), -1}, + {int64(-1), "-1", 0}, + {uint64(1), uint64(1), 0}, + {uint64(1), int64(-1), 1}, + {uint64(1), "1", 0}, + {mysql.NewDecimalFromInt(1, 0), mysql.NewDecimalFromInt(1, 0), 0}, + {mysql.NewDecimalFromInt(1, 0), "1", 0}, + {"1", "1", 0}, + {"1", int64(-1), 1}, + {"1", float64(2), -1}, + {"1", uint64(1), 0}, + {"1", mysql.NewDecimalFromInt(1, 0), 0}, + {"2011-01-01 11:11:11", mysql.Time{time.Now(), mysql.TypeDatetime, 0}, -1}, + {"12:00:00", mysql.ZeroDuration, 1}, + {mysql.ZeroDuration, mysql.ZeroDuration, 0}, + {mysql.Time{time.Now().Add(time.Second * 10), mysql.TypeDatetime, 0}, + mysql.Time{time.Now(), mysql.TypeDatetime, 0}, 1}, + + {nil, 2, -1}, + {nil, nil, 0}, + + {false, nil, 1}, + {false, true, -1}, + {true, true, 0}, + {false, false, 0}, + {true, 2, -1}, + + {float64(1.23), nil, 1}, + {float64(1.23), float32(3.45), -1}, + {float64(354.23), float32(3.45), 1}, + {float64(0.0), float64(3.45), -1}, + {float64(354.23), float64(3.45), 1}, + {float64(3.452), float64(3.452), 0}, + {float32(1.23), nil, 1}, + + {int(432), nil, 1}, + {-4, int(32), -1}, + {int(4), -32, 1}, + {int(432), int8(12), 1}, + {int(23), int8(28), -1}, + {int(123), int8(123), 0}, + {int(432), int16(12), 1}, + {int(23), int16(128), -1}, + {int(123), int16(123), 0}, + {int(432), int32(12), 1}, + {int(23), int32(128), -1}, + {int(123), int32(123), 0}, + {int(432), int64(12), 1}, + {int(23), int64(128), -1}, + {int(123), int64(123), 0}, + {int(432), int(12), 1}, + {int(23), int(123), -1}, + {int8(3), int(3), 0}, + {int16(923), int(180), 1}, + {int32(173), int(120), 1}, + {int64(133), int(183), -1}, + + {uint(23), nil, 1}, + {uint(23), uint(123), -1}, + {uint8(3), uint8(3), 0}, + {uint16(923), uint16(180), 1}, + {uint32(173), uint32(120), 1}, + {uint64(133), uint64(183), -1}, + {uint64(2), int64(-2), 1}, + {uint64(2), int64(1), 1}, + + {"", nil, 1}, + {"", "24", -1}, + {"aasf", "4", 1}, + {"", "", 0}, + + {[]byte(""), nil, 1}, + {[]byte(""), []byte("sff"), -1}, + + {mysql.Time{}, nil, 1}, + {mysql.Time{}, mysql.Time{Time: time.Now(), Type: mysql.TypeDatetime, Fsp: 3}, -1}, + {mysql.Time{Time: time.Now(), Type: mysql.TypeDatetime, Fsp: 3}, "0000-00-00 00:00:00", 1}, + + {mysql.Duration{Duration: time.Duration(34), Fsp: 2}, nil, 1}, + {mysql.Duration{Duration: time.Duration(34), Fsp: 2}, mysql.Duration{Duration: time.Duration(29034), Fsp: 2}, -1}, + {mysql.Duration{Duration: time.Duration(3340), Fsp: 2}, mysql.Duration{Duration: time.Duration(34), Fsp: 2}, 1}, + {mysql.Duration{Duration: time.Duration(34), Fsp: 2}, mysql.Duration{Duration: time.Duration(34), Fsp: 2}, 0}, + + {[]byte{}, []byte{}, 0}, + {[]byte("abc"), []byte("ab"), 1}, + {[]byte("123"), 1234, -1}, + {[]byte{}, nil, 1}, + } + + for _, t := range cmpTbl { + ret, err := Compare(t.lhs, t.rhs) + c.Assert(err, IsNil) + c.Assert(ret, Equals, t.ret) + + ret, err = Compare(t.rhs, t.lhs) + c.Assert(err, IsNil) + c.Assert(ret, Equals, -t.ret) + } + +} diff --git a/util/types/etc.go b/util/types/etc.go index 2ccc87d01bbb7..53c071ee510a2 100644 --- a/util/types/etc.go +++ b/util/types/etc.go @@ -14,7 +14,6 @@ package types import ( - "bytes" "fmt" "io" "reflect" @@ -242,129 +241,6 @@ func compareFloat64With(x float64, b interface{}) int { } } -// Compare returns an integer comparing the interface a to b. -// TODO: compare should return errors instead of panicing. -func Compare(a, b interface{}) int { - switch x := a.(type) { - case nil: - if b != nil { - return -1 - } - - return 0 - case bool: - switch y := b.(type) { - case nil: - return 1 - case bool: - if !x && y { - return -1 - } - - if x == y { - return 0 - } - - return 1 - default: - // Make bool collate before anything except nil and - // other bool for index seeking first non NULL value. - return -1 - } - case float32: - return compareFloat64With(float64(x), b) - case float64: - return compareFloat64With(float64(x), b) - case int8: - return compareInt64With(int64(x), b) - case int16: - return compareInt64With(int64(x), b) - case int32: - return compareInt64With(int64(x), b) - case int: - return compareInt64With(int64(x), b) - case int64: - return compareInt64With(int64(x), b) - case uint8: - return compareUint64With(uint64(x), b) - case uint16: - return compareUint64With(uint64(x), b) - case uint: - return compareUint64With(uint64(x), b) - case uint32: - return compareUint64With(uint64(x), b) - case uint64: - return compareUint64With(uint64(x), b) - case string: - switch y := b.(type) { - case nil: - return 1 - case string: - if x < y { - return -1 - } - - if x == y { - return 0 - } - - return 1 - default: - panic("should never happen") - } - case []byte: - switch y := b.(type) { - case nil: - return 1 - case []byte: - return bytes.Compare(x, y) - default: - panic("should never happen") - } - case mysql.Time: - switch y := b.(type) { - case nil: - return 1 - case mysql.Time: - return x.Compare(y) - case string: - t, err := mysql.ParseTime(y, x.Type, x.Fsp) - if err != nil { - log.Warnf("Failed to convert %s to mysql.Time with err %v", y, err) - return 1 - } - return x.Compare(t) - default: - panic("should never happen") - } - case mysql.Duration: - switch y := b.(type) { - case nil: - return 1 - case mysql.Duration: - if x.Duration < y.Duration { - return -1 - } - - if x.Duration == y.Duration { - return 0 - } - - return 1 - default: - panic("should never happen") - } - case mysql.Decimal: - y, err := ToDecimal(b) - if err != nil { - panic(fmt.Sprintf("should never happen, err: %v", err)) - } - return x.Cmp(y) - default: - panic("should never happen") - } -} - // TODO: collate should return errors from Compare. func collate(x, y []interface{}) (r int) { nx, ny := len(x), len(y) @@ -384,7 +260,13 @@ func collate(x, y []interface{}) (r int) { } for i, xi := range x { - if c := Compare(xi, y[i]); c != 0 { + // TODO: we may remove collate later, so here just panic error. + c, err := Compare(xi, y[i]) + if err != nil { + panic(fmt.Sprintf("should never happend %v", err)) + } + + if c != 0 { return c * r } } diff --git a/util/types/etc_test.go b/util/types/etc_test.go index a4ee136be43cc..7b99167e5206c 100644 --- a/util/types/etc_test.go +++ b/util/types/etc_test.go @@ -102,81 +102,6 @@ func (s *testTypeEtcSuite) TestEOFAsNil(c *C) { c.Assert(err, IsNil) } -func checkCompare(c *C, x, y interface{}, expect int) { - v := Compare(x, y) - c.Assert(v, Equals, expect) -} - -func (s *testTypeEtcSuite) TestCompare(c *C) { - checkCompare(c, nil, 2, -1) - checkCompare(c, nil, nil, 0) - - checkCompare(c, false, nil, 1) - checkCompare(c, false, true, -1) - checkCompare(c, true, true, 0) - checkCompare(c, false, false, 0) - checkCompare(c, true, "", -1) - - checkCompare(c, float64(1.23), nil, 1) - checkCompare(c, float64(1.23), float32(3.45), -1) - checkCompare(c, float64(354.23), float32(3.45), 1) - checkCompare(c, float64(0.0), float64(3.45), -1) - checkCompare(c, float64(354.23), float64(3.45), 1) - checkCompare(c, float64(3.452), float64(3.452), 0) - checkCompare(c, float32(1.23), nil, 1) - - checkCompare(c, int(432), nil, 1) - checkCompare(c, -4, int(32), -1) - checkCompare(c, int(4), -32, 1) - checkCompare(c, int(432), int8(12), 1) - checkCompare(c, int(23), int8(28), -1) - checkCompare(c, int(123), int8(123), 0) - checkCompare(c, int(432), int16(12), 1) - checkCompare(c, int(23), int16(128), -1) - checkCompare(c, int(123), int16(123), 0) - checkCompare(c, int(432), int32(12), 1) - checkCompare(c, int(23), int32(128), -1) - checkCompare(c, int(123), int32(123), 0) - checkCompare(c, int(432), int64(12), 1) - checkCompare(c, int(23), int64(128), -1) - checkCompare(c, int(123), int64(123), 0) - checkCompare(c, int(432), int(12), 1) - checkCompare(c, int(23), int(123), -1) - checkCompare(c, int8(3), int(3), 0) - checkCompare(c, int16(923), int(180), 1) - checkCompare(c, int32(173), int(120), 1) - checkCompare(c, int64(133), int(183), -1) - - checkCompare(c, uint(23), nil, 1) - checkCompare(c, uint(23), uint(123), -1) - checkCompare(c, uint8(3), uint8(3), 0) - checkCompare(c, uint16(923), uint16(180), 1) - checkCompare(c, uint32(173), uint32(120), 1) - checkCompare(c, uint64(133), uint64(183), -1) - - checkCompare(c, "", nil, 1) - checkCompare(c, "", "24", -1) - checkCompare(c, "aasf", "4", 1) - checkCompare(c, "", "", 0) - - checkCompare(c, []byte(""), nil, 1) - checkCompare(c, []byte(""), []byte("sff"), -1) - - checkCompare(c, mysql.Time{}, nil, 1) - checkCompare(c, mysql.Time{}, mysql.Time{Time: time.Now(), Type: 1, Fsp: 3}, -1) - checkCompare(c, mysql.Time{Time: time.Now(), Type: 1, Fsp: 3}, "11:11", 1) - - checkCompare(c, mysql.Duration{Duration: time.Duration(34), Fsp: 2}, nil, 1) - checkCompare(c, mysql.Duration{Duration: time.Duration(34), Fsp: 2}, - mysql.Duration{Duration: time.Duration(29034), Fsp: 2}, -1) - checkCompare(c, mysql.Duration{Duration: time.Duration(3340), Fsp: 2}, - mysql.Duration{Duration: time.Duration(34), Fsp: 2}, 1) - checkCompare(c, mysql.Duration{Duration: time.Duration(34), Fsp: 2}, - mysql.Duration{Duration: time.Duration(34), Fsp: 2}, 0) - - checkCompare(c, mysql.Decimal{}, mysql.Decimal{}, 0) -} - func checkCollate(c *C, x, y []interface{}, expect int) { v := collate(x, y) c.Assert(v, Equals, expect) diff --git a/util/types/helper.go b/util/types/helper.go index 036969ba3b4be..4ab32208df10b 100644 --- a/util/types/helper.go +++ b/util/types/helper.go @@ -36,58 +36,6 @@ func RoundFloat(val float64) float64 { return v } -// CompareInt64 returns an integer comparing the int64 x to y. -func CompareInt64(x, y int64) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} - -// CompareUint64 returns an integer comparing the uint64 x to y. -func CompareUint64(x, y uint64) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} - -// CompareFloat64 returns an integer comparing the float64 x to y. -func CompareFloat64(x, y float64) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} - -// CompareInteger returns an integer comparing the int64 x to the uint64 y. -func CompareInteger(x int64, y uint64) int { - if x < 0 { - return -1 - } - return CompareUint64(uint64(x), y) -} - -// CompareString returns an integer comparing the string x to y. -func CompareString(x, y string) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} - func getMaxFloat(flen int, decimal int) float64 { intPartLen := flen - decimal f := math.Pow10(intPartLen) From dd19e36c2839345cfae8d8fa46a77d74cdc395fa Mon Sep 17 00:00:00 2001 From: siddontang Date: Mon, 7 Sep 2015 13:26:54 +0800 Subject: [PATCH 2/4] expressions:use types.Compare --- expression/expressions/binop.go | 101 +------------------------------- 1 file changed, 1 insertion(+), 100 deletions(-) diff --git a/expression/expressions/binop.go b/expression/expressions/binop.go index efb81fc702d19..c45228207c40c 100644 --- a/expression/expressions/binop.go +++ b/expression/expressions/binop.go @@ -346,105 +346,6 @@ func (o *BinaryOperation) evalLogicOp(ctx context.Context, args map[interface{}] } } -func compareFloatString(a float64, s string) (int, error) { - // MySQL will convert string to a float point value - // MySQL use a very loose conversation, e.g, 123.abc -> 123 - // we should do a trade off whether supporting this feature or using a strict mode - // now we use a strict mode - b, err := types.StrToFloat(s) - if err != nil { - return 0, err - } - return types.CompareFloat64(a, b), nil -} - -func compareStringFloat(s string, a float64) (int, error) { - n, err := compareFloatString(a, s) - return -n, err -} - -// See https://dev.mysql.com/doc/refman/5.7/en/type-conversion.html -func evalCompare(a interface{}, b interface{}) (int, error) { - // TODO: support compare time type with other types - switch x := a.(type) { - case float64: - switch y := b.(type) { - case float64: - return types.CompareFloat64(x, y), nil - case string: - return compareFloatString(x, y) - } - case int64: - switch y := b.(type) { - case int64: - return types.CompareInt64(x, y), nil - case uint64: - return types.CompareInteger(x, y), nil - case string: - return compareFloatString(float64(x), y) - } - case uint64: - switch y := b.(type) { - case uint64: - return types.CompareUint64(x, y), nil - case int64: - return -types.CompareInteger(y, x), nil - case string: - return compareFloatString(float64(x), y) - } - case mysql.Decimal: - switch y := b.(type) { - case mysql.Decimal: - return x.Cmp(y), nil - case string: - f, err := mysql.ConvertToDecimal(y) - if err != nil { - return 0, err - } - return x.Cmp(f), nil - } - case string: - switch y := b.(type) { - case string: - return types.CompareString(x, y), nil - case int64: - return compareStringFloat(x, float64(y)) - case uint64: - return compareStringFloat(x, float64(y)) - case float64: - return compareStringFloat(x, y) - case mysql.Decimal: - f, err := mysql.ConvertToDecimal(x) - if err != nil { - return 0, err - } - return f.Cmp(y), nil - case mysql.Time: - n, err := y.CompareString(x) - return -n, err - case mysql.Duration: - n, err := y.CompareString(x) - return -n, err - } - case mysql.Time: - switch y := b.(type) { - case mysql.Time: - return x.Compare(y), nil - case string: - return x.CompareString(y) - } - case mysql.Duration: - switch y := b.(type) { - case mysql.Duration: - return x.Compare(y), nil - case string: - return x.CompareString(y) - } - } - - return 0, errors.Errorf("invalid compare type %T cmp %T", a, b) -} - // operator: >=, >, <=, <, !=, <>, = <=>, etc. // see https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html func (o *BinaryOperation) evalComparisonOp(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { @@ -459,7 +360,7 @@ func (o *BinaryOperation) evalComparisonOp(ctx context.Context, args map[interfa return nil, nil } - n, err := evalCompare(a, b) + n, err := types.Compare(a, b) if err != nil { return nil, o.traceErr(err) } From 329721d226abd0f50d7ca44c430687f25f27441b Mon Sep 17 00:00:00 2001 From: siddontang Date: Mon, 7 Sep 2015 13:28:32 +0800 Subject: [PATCH 3/4] *:handle error with types.Compare function --- expression/expressions/builtin_control.go | 5 +---- expression/expressions/builtin_groupby.go | 12 ++++++++++-- expression/expressions/unary_test.go | 3 ++- plan/plans/index.go | 10 +++++++++- plan/plans/orderby.go | 9 ++++++++- stmt/stmts/update.go | 7 ++++++- 6 files changed, 36 insertions(+), 10 deletions(-) diff --git a/expression/expressions/builtin_control.go b/expression/expressions/builtin_control.go index 4f34b60cb2a12..811975863ee55 100644 --- a/expression/expressions/builtin_control.go +++ b/expression/expressions/builtin_control.go @@ -65,10 +65,7 @@ func builtinNullIf(args []interface{}, m map[interface{}]interface{}) (interface return v1, nil } - // coerce for later eval compare - x, y := types.Coerce(v1, v2) - - if n, err := evalCompare(x, y); err != nil || n == 0 { + if n, err := types.Compare(v1, v2); err != nil || n == 0 { return nil, err } diff --git a/expression/expressions/builtin_groupby.go b/expression/expressions/builtin_groupby.go index 8e5bf9fa2f875..c6b9ee0a8e2b3 100644 --- a/expression/expressions/builtin_groupby.go +++ b/expression/expressions/builtin_groupby.go @@ -255,7 +255,11 @@ func builtinMax(args []interface{}, ctx map[interface{}]interface{}) (v interfac if max == nil { max = y } else { - if types.Compare(max, y) < 0 { + n, err := types.Compare(max, y) + if err != nil { + return nil, errors.Trace(err) + } + if n < 0 { max = y } } @@ -288,7 +292,11 @@ func builtinMin(args []interface{}, ctx map[interface{}]interface{}) (v interfac if min == nil { min = y } else { - if types.Compare(min, y) > 0 { + n, err := types.Compare(min, y) + if err != nil { + return nil, errors.Trace(err) + } + if n > 0 { min = y } } diff --git a/expression/expressions/unary_test.go b/expression/expressions/unary_test.go index bc539e4fb900e..75c671acb36bf 100644 --- a/expression/expressions/unary_test.go +++ b/expression/expressions/unary_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/expression" mysql "github.com/pingcap/tidb/mysqldef" "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" ) var _ = Suite(&testUnaryOperationSuite{}) @@ -111,7 +112,7 @@ func (s *testUnaryOperationSuite) TestUnaryOp(c *C) { result, err := exprc.Eval(nil, nil) c.Assert(err, IsNil) - ret, err := evalCompare(result, t.result) + ret, err := types.Compare(result, t.result) c.Assert(err, IsNil) c.Assert(ret, Equals, 0) } diff --git a/plan/plans/index.go b/plan/plans/index.go index fb917ff6e177e..e9da52d0baf1c 100644 --- a/plan/plans/index.go +++ b/plan/plans/index.go @@ -14,6 +14,8 @@ package plans import ( + "fmt" + "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" @@ -151,7 +153,13 @@ func indexCompare(a interface{}, b interface{}) int { return -1 } - return types.Compare(a, b) + n, err := types.Compare(a, b) + if err != nil { + // Old compare panics if err, so here we do the same thing now. + // TODO: return err instead of panic. + panic(fmt.Sprintf("should never happend %v", err)) + } + return n } func (r *indexPlan) doSpan(ctx context.Context, txn kv.Transaction, span *indexSpan, f plan.RowIterFunc) error { diff --git a/plan/plans/orderby.go b/plan/plans/orderby.go index 3145e6789bf03..5961f797cfd29 100644 --- a/plan/plans/orderby.go +++ b/plan/plans/orderby.go @@ -19,6 +19,7 @@ import ( "strings" "github.com/juju/errors" + "github.com/ngaut/log" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/expressions" @@ -102,7 +103,13 @@ func (t *orderByTable) Less(i, j int) bool { v1 := t.Rows[i].Key[index] v2 := t.Rows[j].Key[index] - ret := types.Compare(v1, v2) + ret, err := types.Compare(v1, v2) + if err != nil { + // we just have to log this error and skip it. + // TODO: record this error and handle it out later. + log.Errorf("compare %v %v err %v", v1, v2, err) + } + if !asc { ret = -ret } diff --git a/stmt/stmts/update.go b/stmt/stmts/update.go index 1d1ff642b649d..249d9805c555b 100644 --- a/stmt/stmts/update.go +++ b/stmt/stmts/update.go @@ -167,7 +167,12 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl continue } od := oldData[i] - if types.Compare(d, od) != 0 { + n, err := types.Compare(d, od) + if err != nil { + return errors.Trace(err) + } + + if n != 0 { rowChanged = true break } From 213793ba3808c5bff8ef8795fa77f1e5b476b400 Mon Sep 17 00:00:00 2001 From: siddontang Date: Mon, 7 Sep 2015 13:28:58 +0800 Subject: [PATCH 4/4] expressions:remove unnecessary test --- expression/expressions/binop_test.go | 42 ---------------------------- util/types/compare_test.go | 6 ++-- 2 files changed, 3 insertions(+), 45 deletions(-) diff --git a/expression/expressions/binop_test.go b/expression/expressions/binop_test.go index fe8ab53cd94ec..0d90dc49d35c0 100644 --- a/expression/expressions/binop_test.go +++ b/expression/expressions/binop_test.go @@ -15,7 +15,6 @@ package expressions import ( "errors" - "time" . "github.com/pingcap/check" "github.com/pingcap/tidb/expression" @@ -106,44 +105,6 @@ func (s *testBinOpSuite) TestComparisonOp(c *C) { c.Assert(v, IsNil) } - // test evalCompare function - cmpTbl := []struct { - lhs interface{} - rhs interface{} - ret int // 0, 1, -1 - }{ - {float64(1), float64(1), 0}, - {float64(1), "1", 0}, - {int64(1), int64(1), 0}, - {int64(-1), uint64(1), -1}, - {int64(-1), "-1", 0}, - {uint64(1), uint64(1), 0}, - {uint64(1), int64(-1), 1}, - {uint64(1), "1", 0}, - {mysql.NewDecimalFromInt(1, 0), mysql.NewDecimalFromInt(1, 0), 0}, - {mysql.NewDecimalFromInt(1, 0), "1", 0}, - {"1", "1", 0}, - {"1", int64(-1), 1}, - {"1", float64(2), -1}, - {"1", uint64(1), 0}, - {"1", mysql.NewDecimalFromInt(1, 0), 0}, - {"2011-01-01 11:11:11", mysql.Time{Time: time.Now(), Type: mysql.TypeDatetime, Fsp: 0}, -1}, - {"12:00:00", mysql.ZeroDuration, 1}, - {mysql.ZeroDuration, mysql.ZeroDuration, 0}, - {mysql.Time{Time: time.Now().Add(time.Second * 10), Type: mysql.TypeDatetime, Fsp: 0}, - mysql.Time{Time: time.Now(), Type: mysql.TypeDatetime, Fsp: 0}, 1}, - } - - for _, t := range cmpTbl { - ret, err := evalCompare(t.lhs, t.rhs) - c.Assert(err, IsNil) - c.Assert(ret, Equals, t.ret) - - ret, err = evalCompare(t.rhs, t.lhs) - c.Assert(err, IsNil) - c.Assert(ret, Equals, -t.ret) - } - // test error mock := mockExpr{ isStatic: false, @@ -194,9 +155,6 @@ func (s *testBinOpSuite) TestComparisonOp(c *C) { expr := BinaryOperation{Op: opcode.Plus, L: Value{1}, R: Value{1}} _, err := expr.evalComparisonOp(nil, nil) c.Assert(err, NotNil) - - _, err = evalCompare(1, 1) - c.Assert(err, NotNil) } func (s *testBinOpSuite) TestIdentRelOp(c *C) { diff --git a/util/types/compare_test.go b/util/types/compare_test.go index 8f89d298834f0..64372c124198f 100644 --- a/util/types/compare_test.go +++ b/util/types/compare_test.go @@ -46,11 +46,11 @@ func (s *testCompareSuite) TestCompare(c *C) { {"1", float64(2), -1}, {"1", uint64(1), 0}, {"1", mysql.NewDecimalFromInt(1, 0), 0}, - {"2011-01-01 11:11:11", mysql.Time{time.Now(), mysql.TypeDatetime, 0}, -1}, + {"2011-01-01 11:11:11", mysql.Time{Time: time.Now(), Type: mysql.TypeDatetime, Fsp: 0}, -1}, {"12:00:00", mysql.ZeroDuration, 1}, {mysql.ZeroDuration, mysql.ZeroDuration, 0}, - {mysql.Time{time.Now().Add(time.Second * 10), mysql.TypeDatetime, 0}, - mysql.Time{time.Now(), mysql.TypeDatetime, 0}, 1}, + {mysql.Time{Time: time.Now().Add(time.Second * 10), Type: mysql.TypeDatetime, Fsp: 0}, + mysql.Time{Time: time.Now(), Type: mysql.TypeDatetime, Fsp: 0}, 1}, {nil, 2, -1}, {nil, nil, 0},