diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 0430499bbc09d..9983aa9f48d1d 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -174,7 +174,42 @@ type decodeFunctionClass struct { } func (c *decodeFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenByArgs("FUNCTION", "DECODE") + if err := c.verifyArgs(args); err != nil { + return nil, errors.Trace(err) + } + + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString) + + bf.tp.Flen = args[0].GetType().Flen + sig := &builtinDecodeSig{bf} + return sig, nil +} + +type builtinDecodeSig struct { + baseBuiltinFunc +} + +func (b *builtinDecodeSig) Clone() builtinFunc { + newSig := &builtinDecodeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals DECODE(str, password_str). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_decode +func (b *builtinDecodeSig) evalString(row chunk.Row) (string, bool, error) { + dataStr, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + + passwordStr, isNull, err := b.args[1].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + + decodeStr, err := encrypt.SQLDecode(dataStr, passwordStr) + return decodeStr, false, err } type desDecryptFunctionClass struct { @@ -198,7 +233,42 @@ type encodeFunctionClass struct { } func (c *encodeFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenByArgs("FUNCTION", "ENCODE") + if err := c.verifyArgs(args); err != nil { + return nil, errors.Trace(err) + } + + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString) + + bf.tp.Flen = args[0].GetType().Flen + sig := &builtinEncodeSig{bf} + return sig, nil +} + +type builtinEncodeSig struct { + baseBuiltinFunc +} + +func (b *builtinEncodeSig) Clone() builtinFunc { + newSig := &builtinEncodeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals ENCODE(crypt_str, password_str). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_encode +func (b *builtinEncodeSig) evalString(row chunk.Row) (string, bool, error) { + decodeStr, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + + passwordStr, isNull, err := b.args[1].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + + dataStr, err := encrypt.SQLEncode(decodeStr, passwordStr) + return dataStr, false, err } type encryptFunctionClass struct { diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 5d80199f15778..6fa7e90e22dc8 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -26,6 +26,56 @@ import ( "github.com/pingcap/tidb/util/testleak" ) +var cryptTests = []struct { + origin interface{} + password interface{} + crypt interface{} +}{ + {"", "", ""}, + {"pingcap", "1234567890123456", "2C35B5A4ADF391"}, + {"pingcap", "asdfjasfwefjfjkj", "351CC412605905"}, + {"pingcap123", "123456789012345678901234", "7698723DC6DFE7724221"}, + {"pingcap#%$%^", "*^%YTu1234567", "8634B9C55FF55E5B6328F449"}, + {"pingcap", "", "4A77B524BD2C5C"}, + {"分布式データベース", "pass1234@#$%%^^&", "80CADC8D328B3026D04FB285F36FED04BBCA0CC685BF78B1E687CE"}, + {"分布式データベース", "分布式7782734adgwy1242", "0E24CFEF272EE32B6E0BFBDB89F29FB43B4B30DAA95C3F914444BC"}, + {"pingcap", "密匙", "CE5C02A5010010"}, + {"pingcap数据库", "数据库passwd12345667", "36D5F90D3834E30E396BE3226E3B4ED3"}, + {"数据库5667", 123.435, "B22196D0569386237AE12F8AAB"}, + {nil, "数据库passwd12345667", nil}, +} + +func (s *testEvaluatorSuite) TestSQLDecode(c *C) { + defer testleak.AfterTest(c)() + fc := funcs[ast.Decode] + for _, tt := range cryptTests { + str := types.NewDatum(tt.origin) + password := types.NewDatum(tt.password) + + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{str, password})) + crypt, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(toHex(crypt), DeepEquals, types.NewDatum(tt.crypt)) + } + s.testNullInput(c, ast.Decode) +} + +func (s *testEvaluatorSuite) TestSQLEncode(c *C) { + defer testleak.AfterTest(c)() + fc := funcs[ast.Encode] + for _, test := range cryptTests { + password := types.NewDatum(test.password) + cryptStr := fromHex(test.crypt) + + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{cryptStr, password})) + str, err := evalBuiltinFunc(f, chunk.Row{}) + + c.Assert(err, IsNil) + c.Assert(str, DeepEquals, types.NewDatum(test.origin)) + } + s.testNullInput(c, ast.Encode) +} + var aesTests = []struct { origin interface{} key interface{} diff --git a/util/encrypt/crypt.go b/util/encrypt/crypt.go new file mode 100644 index 0000000000000..578486f4e9a8a --- /dev/null +++ b/util/encrypt/crypt.go @@ -0,0 +1,138 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package encrypt + +type randStruct struct { + seed1 uint32 + seed2 uint32 + maxValue uint32 + maxValueDbl float64 +} + +// randomInit random generation structure initialization +func (rs *randStruct) randomInit(password []byte, length int) { + // Generate binary hash from raw text string + var nr, add, nr2, tmp uint32 + nr = 1345345333 + add = 7 + nr2 = 0x12345671 + + for i := 0; i < length; i++ { + pswChar := password[i] + if pswChar == ' ' || pswChar == '\t' { + continue + } + tmp = uint32(pswChar) + nr ^= (((nr & 63) + add) * tmp) + (nr << 8) + nr2 += (nr2 << 8) ^ nr + add += tmp + } + + seed1 := nr & ((uint32(1) << 31) - uint32(1)) + seed2 := nr2 & ((uint32(1) << 31) - uint32(1)) + + // New (MySQL 3.21+) random generation structure initialization + rs.maxValue = 0x3FFFFFFF + rs.maxValueDbl = float64(rs.maxValue) + rs.seed1 = seed1 % rs.maxValue + rs.seed2 = seed2 % rs.maxValue +} + +func (rs *randStruct) myRand() float64 { + rs.seed1 = (rs.seed1*3 + rs.seed2) % rs.maxValue + rs.seed2 = (rs.seed1 + rs.seed2 + 33) % rs.maxValue + + return ((float64(rs.seed1)) / rs.maxValueDbl) +} + +// sqlCrypt use to store initialization results +type sqlCrypt struct { + rand randStruct + orgRand randStruct + + decodeBuff [256]byte + encodeBuff [256]byte + shift uint32 +} + +func (sc *sqlCrypt) init(password []byte, length int) { + sc.rand.randomInit(password, length) + + for i := 0; i <= 255; i++ { + sc.decodeBuff[i] = byte(i) + } + + for i := 0; i <= 255; i++ { + idx := uint32(sc.rand.myRand() * 255.0) + a := sc.decodeBuff[idx] + sc.decodeBuff[idx] = sc.decodeBuff[i] + sc.decodeBuff[i] = a + } + + for i := 0; i <= 255; i++ { + sc.encodeBuff[sc.decodeBuff[i]] = byte(i) + } + + sc.orgRand = sc.rand + sc.shift = 0 +} + +func (sc *sqlCrypt) reinit() { + sc.shift = 0 + sc.rand = sc.orgRand +} + +func (sc *sqlCrypt) encode(str []byte, length int) { + for i := 0; i < length; i++ { + sc.shift ^= uint32(sc.rand.myRand() * 255.0) + idx := uint32(str[i]) + str[i] = sc.encodeBuff[idx] ^ byte(sc.shift) + sc.shift ^= idx + } +} + +func (sc *sqlCrypt) decode(str []byte, length int) { + for i := 0; i < length; i++ { + sc.shift ^= uint32(sc.rand.myRand() * 255.0) + idx := uint32(str[i] ^ byte(sc.shift)) + str[i] = sc.decodeBuff[idx] + sc.shift ^= uint32(str[i]) + } +} + +//SQLDecode Function to handle the decode() function +func SQLDecode(str string, password string) (string, error) { + var sc sqlCrypt + + strByte := []byte(str) + passwdByte := []byte(password) + + sc.init(passwdByte, len(passwdByte)) + sc.decode(strByte, len(strByte)) + + return string(strByte), nil +} + +// SQLEncode Function to handle the encode() function +func SQLEncode(cryptStr string, password string) (string, error) { + var sc sqlCrypt + + cryptStrByte := []byte(cryptStr) + passwdByte := []byte(password) + + sc.init(passwdByte, len(passwdByte)) + sc.encode(cryptStrByte, len(cryptStrByte)) + + return string(cryptStrByte), nil +} diff --git a/util/encrypt/crypt_test.go b/util/encrypt/crypt_test.go new file mode 100644 index 0000000000000..524bae486aa39 --- /dev/null +++ b/util/encrypt/crypt_test.go @@ -0,0 +1,84 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package encrypt + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/util/testleak" +) + +func (s *testEncryptSuite) TestSQLDecode(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + str string + passwd string + expect string + isError bool + }{ + {"", "", "", false}, + {"pingcap", "1234567890123456", "2C35B5A4ADF391", false}, + {"pingcap", "asdfjasfwefjfjkj", "351CC412605905", false}, + {"pingcap123", "123456789012345678901234", "7698723DC6DFE7724221", false}, + {"pingcap#%$%^", "*^%YTu1234567", "8634B9C55FF55E5B6328F449", false}, + {"pingcap", "", "4A77B524BD2C5C", false}, + {"分布式データベース", "pass1234@#$%%^^&", "80CADC8D328B3026D04FB285F36FED04BBCA0CC685BF78B1E687CE", false}, + {"分布式データベース", "分布式7782734adgwy1242", "0E24CFEF272EE32B6E0BFBDB89F29FB43B4B30DAA95C3F914444BC", false}, + {"pingcap", "密匙", "CE5C02A5010010", false}, + {"pingcap数据库", "数据库passwd12345667", "36D5F90D3834E30E396BE3226E3B4ED3", false}, + } + + for _, t := range tests { + crypted, err := SQLDecode(t.str, t.passwd) + if t.isError { + c.Assert(err, NotNil, Commentf("%v", t)) + continue + } + c.Assert(err, IsNil, Commentf("%v", t)) + result := toHex([]byte(crypted)) + c.Assert(result, Equals, t.expect, Commentf("%v", t)) + } +} + +func (s *testEncryptSuite) TestSQLEncode(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + str string + passwd string + expect string + isError bool + }{ + {"", "", "", false}, + {"pingcap", "1234567890123456", "pingcap", false}, + {"pingcap", "asdfjasfwefjfjkj", "pingcap", false}, + {"pingcap123", "123456789012345678901234", "pingcap123", false}, + {"pingcap#%$%^", "*^%YTu1234567", "pingcap#%$%^", false}, + {"pingcap", "", "pingcap", false}, + {"分布式データベース", "pass1234@#$%%^^&", "分布式データベース", false}, + {"分布式データベース", "分布式7782734adgwy1242", "分布式データベース", false}, + {"pingcap", "密匙", "pingcap", false}, + {"pingcap数据库", "数据库passwd12345667", "pingcap数据库", false}, + } + + for _, t := range tests { + crypted, err := SQLDecode(t.str, t.passwd) + uncrypte, err := SQLEncode(crypted, t.passwd) + + if t.isError { + c.Assert(err, NotNil, Commentf("%v", t)) + continue + } + c.Assert(err, IsNil, Commentf("%v", t)) + c.Assert(uncrypte, Equals, t.expect, Commentf("%v", t)) + } +}