Skip to content

Commit

Permalink
Add two buildin function ( decode and encode) (#7622)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanjingquan authored and zz-jason committed Sep 10, 2018
1 parent d7d1309 commit 8b1feeb
Show file tree
Hide file tree
Showing 4 changed files with 344 additions and 2 deletions.
74 changes: 72 additions & 2 deletions expression/builtin_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,42 @@ type decodeFunctionClass struct {
}

func (c *decodeFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
return nil, errFunctionNotExists.GenByArgs("FUNCTION", "DECODE")
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString)

bf.tp.Flen = args[0].GetType().Flen
sig := &builtinDecodeSig{bf}
return sig, nil
}

type builtinDecodeSig struct {
baseBuiltinFunc
}

func (b *builtinDecodeSig) Clone() builtinFunc {
newSig := &builtinDecodeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalString evals DECODE(str, password_str).
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_decode
func (b *builtinDecodeSig) evalString(row chunk.Row) (string, bool, error) {
dataStr, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}

passwordStr, isNull, err := b.args[1].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}

decodeStr, err := encrypt.SQLDecode(dataStr, passwordStr)
return decodeStr, false, err
}

type desDecryptFunctionClass struct {
Expand All @@ -198,7 +233,42 @@ type encodeFunctionClass struct {
}

func (c *encodeFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
return nil, errFunctionNotExists.GenByArgs("FUNCTION", "ENCODE")
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString)

bf.tp.Flen = args[0].GetType().Flen
sig := &builtinEncodeSig{bf}
return sig, nil
}

type builtinEncodeSig struct {
baseBuiltinFunc
}

func (b *builtinEncodeSig) Clone() builtinFunc {
newSig := &builtinEncodeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalString evals ENCODE(crypt_str, password_str).
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_encode
func (b *builtinEncodeSig) evalString(row chunk.Row) (string, bool, error) {
decodeStr, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}

passwordStr, isNull, err := b.args[1].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}

dataStr, err := encrypt.SQLEncode(decodeStr, passwordStr)
return dataStr, false, err
}

type encryptFunctionClass struct {
Expand Down
50 changes: 50 additions & 0 deletions expression/builtin_encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,56 @@ import (
"github.com/pingcap/tidb/util/testleak"
)

var cryptTests = []struct {
origin interface{}
password interface{}
crypt interface{}
}{
{"", "", ""},
{"pingcap", "1234567890123456", "2C35B5A4ADF391"},
{"pingcap", "asdfjasfwefjfjkj", "351CC412605905"},
{"pingcap123", "123456789012345678901234", "7698723DC6DFE7724221"},
{"pingcap#%$%^", "*^%YTu1234567", "8634B9C55FF55E5B6328F449"},
{"pingcap", "", "4A77B524BD2C5C"},
{"分布式データベース", "pass1234@#$%%^^&", "80CADC8D328B3026D04FB285F36FED04BBCA0CC685BF78B1E687CE"},
{"分布式データベース", "分布式7782734adgwy1242", "0E24CFEF272EE32B6E0BFBDB89F29FB43B4B30DAA95C3F914444BC"},
{"pingcap", "密匙", "CE5C02A5010010"},
{"pingcap数据库", "数据库passwd12345667", "36D5F90D3834E30E396BE3226E3B4ED3"},
{"数据库5667", 123.435, "B22196D0569386237AE12F8AAB"},
{nil, "数据库passwd12345667", nil},
}

func (s *testEvaluatorSuite) TestSQLDecode(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.Decode]
for _, tt := range cryptTests {
str := types.NewDatum(tt.origin)
password := types.NewDatum(tt.password)

f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{str, password}))
crypt, err := evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, IsNil)
c.Assert(toHex(crypt), DeepEquals, types.NewDatum(tt.crypt))
}
s.testNullInput(c, ast.Decode)
}

func (s *testEvaluatorSuite) TestSQLEncode(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.Encode]
for _, test := range cryptTests {
password := types.NewDatum(test.password)
cryptStr := fromHex(test.crypt)

f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{cryptStr, password}))
str, err := evalBuiltinFunc(f, chunk.Row{})

c.Assert(err, IsNil)
c.Assert(str, DeepEquals, types.NewDatum(test.origin))
}
s.testNullInput(c, ast.Encode)
}

var aesTests = []struct {
origin interface{}
key interface{}
Expand Down
138 changes: 138 additions & 0 deletions util/encrypt/crypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package encrypt

type randStruct struct {
seed1 uint32
seed2 uint32
maxValue uint32
maxValueDbl float64
}

// randomInit random generation structure initialization
func (rs *randStruct) randomInit(password []byte, length int) {
// Generate binary hash from raw text string
var nr, add, nr2, tmp uint32
nr = 1345345333
add = 7
nr2 = 0x12345671

for i := 0; i < length; i++ {
pswChar := password[i]
if pswChar == ' ' || pswChar == '\t' {
continue
}
tmp = uint32(pswChar)
nr ^= (((nr & 63) + add) * tmp) + (nr << 8)
nr2 += (nr2 << 8) ^ nr
add += tmp
}

seed1 := nr & ((uint32(1) << 31) - uint32(1))
seed2 := nr2 & ((uint32(1) << 31) - uint32(1))

// New (MySQL 3.21+) random generation structure initialization
rs.maxValue = 0x3FFFFFFF
rs.maxValueDbl = float64(rs.maxValue)
rs.seed1 = seed1 % rs.maxValue
rs.seed2 = seed2 % rs.maxValue
}

func (rs *randStruct) myRand() float64 {
rs.seed1 = (rs.seed1*3 + rs.seed2) % rs.maxValue
rs.seed2 = (rs.seed1 + rs.seed2 + 33) % rs.maxValue

return ((float64(rs.seed1)) / rs.maxValueDbl)
}

// sqlCrypt use to store initialization results
type sqlCrypt struct {
rand randStruct
orgRand randStruct

decodeBuff [256]byte
encodeBuff [256]byte
shift uint32
}

func (sc *sqlCrypt) init(password []byte, length int) {
sc.rand.randomInit(password, length)

for i := 0; i <= 255; i++ {
sc.decodeBuff[i] = byte(i)
}

for i := 0; i <= 255; i++ {
idx := uint32(sc.rand.myRand() * 255.0)
a := sc.decodeBuff[idx]
sc.decodeBuff[idx] = sc.decodeBuff[i]
sc.decodeBuff[i] = a
}

for i := 0; i <= 255; i++ {
sc.encodeBuff[sc.decodeBuff[i]] = byte(i)
}

sc.orgRand = sc.rand
sc.shift = 0
}

func (sc *sqlCrypt) reinit() {
sc.shift = 0
sc.rand = sc.orgRand
}

func (sc *sqlCrypt) encode(str []byte, length int) {
for i := 0; i < length; i++ {
sc.shift ^= uint32(sc.rand.myRand() * 255.0)
idx := uint32(str[i])
str[i] = sc.encodeBuff[idx] ^ byte(sc.shift)
sc.shift ^= idx
}
}

func (sc *sqlCrypt) decode(str []byte, length int) {
for i := 0; i < length; i++ {
sc.shift ^= uint32(sc.rand.myRand() * 255.0)
idx := uint32(str[i] ^ byte(sc.shift))
str[i] = sc.decodeBuff[idx]
sc.shift ^= uint32(str[i])
}
}

//SQLDecode Function to handle the decode() function
func SQLDecode(str string, password string) (string, error) {
var sc sqlCrypt

strByte := []byte(str)
passwdByte := []byte(password)

sc.init(passwdByte, len(passwdByte))
sc.decode(strByte, len(strByte))

return string(strByte), nil
}

// SQLEncode Function to handle the encode() function
func SQLEncode(cryptStr string, password string) (string, error) {
var sc sqlCrypt

cryptStrByte := []byte(cryptStr)
passwdByte := []byte(password)

sc.init(passwdByte, len(passwdByte))
sc.encode(cryptStrByte, len(cryptStrByte))

return string(cryptStrByte), nil
}
84 changes: 84 additions & 0 deletions util/encrypt/crypt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package encrypt

import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/util/testleak"
)

func (s *testEncryptSuite) TestSQLDecode(c *C) {
defer testleak.AfterTest(c)()
tests := []struct {
str string
passwd string
expect string
isError bool
}{
{"", "", "", false},
{"pingcap", "1234567890123456", "2C35B5A4ADF391", false},
{"pingcap", "asdfjasfwefjfjkj", "351CC412605905", false},
{"pingcap123", "123456789012345678901234", "7698723DC6DFE7724221", false},
{"pingcap#%$%^", "*^%YTu1234567", "8634B9C55FF55E5B6328F449", false},
{"pingcap", "", "4A77B524BD2C5C", false},
{"分布式データベース", "pass1234@#$%%^^&", "80CADC8D328B3026D04FB285F36FED04BBCA0CC685BF78B1E687CE", false},
{"分布式データベース", "分布式7782734adgwy1242", "0E24CFEF272EE32B6E0BFBDB89F29FB43B4B30DAA95C3F914444BC", false},
{"pingcap", "密匙", "CE5C02A5010010", false},
{"pingcap数据库", "数据库passwd12345667", "36D5F90D3834E30E396BE3226E3B4ED3", false},
}

for _, t := range tests {
crypted, err := SQLDecode(t.str, t.passwd)
if t.isError {
c.Assert(err, NotNil, Commentf("%v", t))
continue
}
c.Assert(err, IsNil, Commentf("%v", t))
result := toHex([]byte(crypted))
c.Assert(result, Equals, t.expect, Commentf("%v", t))
}
}

func (s *testEncryptSuite) TestSQLEncode(c *C) {
defer testleak.AfterTest(c)()
tests := []struct {
str string
passwd string
expect string
isError bool
}{
{"", "", "", false},
{"pingcap", "1234567890123456", "pingcap", false},
{"pingcap", "asdfjasfwefjfjkj", "pingcap", false},
{"pingcap123", "123456789012345678901234", "pingcap123", false},
{"pingcap#%$%^", "*^%YTu1234567", "pingcap#%$%^", false},
{"pingcap", "", "pingcap", false},
{"分布式データベース", "pass1234@#$%%^^&", "分布式データベース", false},
{"分布式データベース", "分布式7782734adgwy1242", "分布式データベース", false},
{"pingcap", "密匙", "pingcap", false},
{"pingcap数据库", "数据库passwd12345667", "pingcap数据库", false},
}

for _, t := range tests {
crypted, err := SQLDecode(t.str, t.passwd)
uncrypte, err := SQLEncode(crypted, t.passwd)

if t.isError {
c.Assert(err, NotNil, Commentf("%v", t))
continue
}
c.Assert(err, IsNil, Commentf("%v", t))
c.Assert(uncrypte, Equals, t.expect, Commentf("%v", t))
}
}

0 comments on commit 8b1feeb

Please sign in to comment.