diff --git a/digester.go b/digester.go
new file mode 100644
index 000000000..28911b5d7
--- /dev/null
+++ b/digester.go
@@ -0,0 +1,342 @@
+// Copyright 2019 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package parser
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"fmt"
+	hash2 "hash"
+	"strings"
+	"sync"
+	"unicode"
+)
+
+// DigestHash generates the digest of statements.
+// It computes a hash over the normalized form of the statement text,
+// which strips the variable parts of a statement (such as literal values) but keeps its structure.
+//
+// For example, both DigestHash('select 1') and DigestHash('select 2') => e1c71d1661ae46e09b7aaec1c390957f0d6260410df4e4bc71b9c8d681021471
+func DigestHash(sql string) (result string) {
+	d := digesterPool.Get().(*sqlDigester)
+	result = d.doDigest(sql)
+	digesterPool.Put(d)
+	return
+}
+
+// Normalize generates the normalized statements.
+// It returns the normalized form of the statement text,
+// which strips the variable parts of a statement (such as literal values) but keeps its structure.
+//
+// For example, Normalize('select 1 from b where a = 1') => 'select ? from b where a = ?'
+func Normalize(sql string) (result string) {
+	d := digesterPool.Get().(*sqlDigester)
+	result = d.doDigestText(sql)
+	digesterPool.Put(d)
+	return
+}
+
+var digesterPool = sync.Pool{
+	New: func() interface{} {
+		return &sqlDigester{
+			lexer:  NewScanner(""),
+			hasher: sha256.New(),
+		}
+	},
+}
+
+// sqlDigester is used to compute DigestHash or Normalize for sql.
+type sqlDigester struct {
+	buffer bytes.Buffer
+	lexer  *Scanner
+	hasher hash2.Hash
+	tokens tokenDeque
+}
+
+func (d *sqlDigester) doDigest(sql string) (result string) {
+	d.normalize(sql)
+	d.hasher.Write(d.buffer.Bytes())
+	d.buffer.Reset()
+	result = fmt.Sprintf("%x", d.hasher.Sum(nil))
+	d.hasher.Reset()
+	return
+}
+
+func (d *sqlDigester) doDigestText(sql string) (result string) {
+	d.normalize(sql)
+	result = string(d.buffer.Bytes())
+	d.lexer.reset("")
+	d.buffer.Reset()
+	return
+}
+
+const (
+	// genericSymbol represents the parameter holder ("?") in a statement.
+	// It can be any value as long as it does not collide with other token values.
+	genericSymbol = -1
+	// genericSymbolList represents a parameter holder list ("?, ?, ...") in a statement.
+	// It can be any value as long as it does not collide with other token values.
+	genericSymbolList = -2
+)
+
+func (d *sqlDigester) normalize(sql string) {
+	d.lexer.reset(sql)
+	for {
+		tok, pos, lit := d.lexer.scan()
+		if tok == unicode.ReplacementChar && d.lexer.r.eof() {
+			break
+		}
+		if pos.Offset == len(sql) {
+			break
+		}
+		currTok := token{tok, strings.ToLower(lit)}
+
+		if d.reduceOptimizerHint(&currTok) {
+			continue
+		}
+
+		d.reduceLit(&currTok)
+
+		d.tokens = append(d.tokens, currTok)
+	}
+	d.lexer.reset("")
+	for i, token := range d.tokens {
+		d.buffer.WriteString(token.lit)
+		if i != len(d.tokens)-1 {
+			d.buffer.WriteRune(' ')
+		}
+	}
+	d.tokens = d.tokens[:0]
+}
+
+func (d *sqlDigester) reduceOptimizerHint(tok *token) (reduced bool) {
+	// ignore /*+..*/
+	if tok.tok == hintBegin {
+		for {
+			tok, _, _ := d.lexer.scan()
+			if tok == 0 || (tok == unicode.ReplacementChar && d.lexer.r.eof()) {
+				break
+			}
+			if tok == hintEnd {
+				reduced = true
+				break
+			}
+		}
+		return
+	}
+
+	// ignore force/use/ignore index(x)
+	if tok.lit == "index" {
+		toks := d.tokens.back(1)
+		if len(toks) > 0 {
+			switch strings.ToLower(toks[0].lit) {
+			case "force", "use", "ignore":
+				for {
+					tok, _, lit := d.lexer.scan()
+					if tok == 0 || (tok == unicode.ReplacementChar && d.lexer.r.eof()) {
+						break
+					}
+					if lit == ")" {
+						reduced = true
+						d.tokens.popBack(1)
+						break
+					}
+				}
+				return
+			}
+		}
+	}
+
+	// ignore straight_join
+	if tok.lit == "straight_join" {
+		tok.lit = "join"
+		return
+	}
+	return
+}
+
+func (d *sqlDigester) reduceLit(currTok *token) {
+	if !d.isLit(*currTok) {
+		return
+	}
+	// count(*) => count(?)
+	if currTok.lit == "*" {
+		if d.isStarParam() {
+			currTok.tok = genericSymbol
+			currTok.lit = "?"
+		}
+		return
+	}
+
+	// "-x" or "+x" => "x"
+	if d.isPrefixByUnary(currTok.tok) {
+		d.tokens.popBack(1)
+	}
+
+	// "?, ?, ?, ?" => "..."
+	last2 := d.tokens.back(2)
+	if d.isGenericList(last2) {
+		d.tokens.popBack(2)
+		currTok.tok = genericSymbolList
+		currTok.lit = "..."
+		return
+	}
+
+	// "order by 1" => "order by 1": ordinal positions in order/group by are kept
+	if currTok.tok == intLit {
+		if d.isOrderOrGroupBy() {
+			return
+		}
+	}
+
+	// 2 => ?
+	currTok.tok = genericSymbol
+	currTok.lit = "?"
+	return
+}
+
+func (d *sqlDigester) isPrefixByUnary(currTok int) (isUnary bool) {
+	if !d.isNumLit(currTok) {
+		return
+	}
+	last := d.tokens.back(1)
+	if last == nil {
+		return
+	}
+	// the preceding token must be '-' or '+'
+	if last[0].lit != "-" && last[0].lit != "+" {
+		return
+	}
+	last2 := d.tokens.back(2)
+	if last2 == nil {
+		isUnary = true
+		return
+	}
+	// '(-x' or ',-x' or ',+x' or '--x' or '+-x'
+	switch last2[0].lit {
+	case "(", ",", "+", "-", ">=", "is", "<=", "=", "<", ">":
+		isUnary = true
+	default:
+	}
+	// select -x or select +x
+	last2Lit := strings.ToLower(last2[0].lit)
+	if last2Lit == "select" {
+		isUnary = true
+	}
+	return
+}
+
+func (d *sqlDigester) isGenericList(last2 []token) (generic bool) {
+	if len(last2) < 2 {
+		return false
+	}
+	if !d.isComma(last2[1]) {
+		return false
+	}
+	switch last2[0].tok {
+	case genericSymbol, genericSymbolList:
+		generic = true
+	default:
+	}
+	return
+}
+
+func (d *sqlDigester) isOrderOrGroupBy() (orderOrGroupBy bool) {
+	var (
+		last []token
+		n    int
+	)
+	// skip numeric item lists, e.g. "order by 1, 2, 3" should NOT be converted to "order by ?, ?, ?"
+	for n = 2; ; n += 2 {
+		last = d.tokens.back(n)
+		if len(last) < 2 {
+			return false
+		}
+		if !d.isComma(last[1]) {
+			break
+		}
+	}
"group by (1, 2)" should not convert to "group by (?, ?)" + if last[1].lit == "(" { + last = d.tokens.back(n + 1) + if len(last) < 2 { + return false + } + } + orderOrGroupBy = (last[0].lit == "order" || last[0].lit == "group") && last[1].lit == "by" + return +} + +func (d *sqlDigester) isStarParam() (starParam bool) { + last := d.tokens.back(1) + if last == nil { + starParam = false + return + } + starParam = last[0].lit == "(" + return +} + +func (d *sqlDigester) isLit(t token) (beLit bool) { + tok := t.tok + if d.isNumLit(tok) || tok == stringLit || tok == bitLit { + beLit = true + } else if t.lit == "*" { + beLit = true + } + return +} + +func (d *sqlDigester) isNumLit(tok int) (beNum bool) { + switch tok { + case intLit, decLit, floatLit, hexLit: + beNum = true + default: + } + return +} + +func (d *sqlDigester) isComma(tok token) (isComma bool) { + isComma = tok.lit == "," + return +} + +type token struct { + tok int + lit string +} + +type tokenDeque []token + +func (s *tokenDeque) pushBack(t token) { + *s = append(*s, t) +} + +func (s *tokenDeque) popBack(n int) (t []token) { + if len(*s) < n { + t = nil + return + } + t = (*s)[len(*s)-n:] + *s = (*s)[:len(*s)-n] + return +} + +func (s *tokenDeque) back(n int) (t []token) { + if len(*s)-n < 0 { + return + } + t = (*s)[len(*s)-n:] + return +} diff --git a/digester_test.go b/digester_test.go new file mode 100644 index 000000000..bf7e8e9db --- /dev/null +++ b/digester_test.go @@ -0,0 +1,94 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package parser_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/parser" +) + +var _ = Suite(&testSQLDigestSuite{}) + +type testSQLDigestSuite struct { +} + +func (s *testSQLDigestSuite) TestNormalize(c *C) { + tests := []struct { + input string + expect string + }{ + {"SELECT 1", "select ?"}, + {"select * from b where id = 1", "select * from b where id = ?"}, + {"select 1 from b where id in (1, 3, '3', 1, 2, 3, 4)", "select ? from b where id in ( ... )"}, + {"select 1 from b where id in (1, a, 4)", "select ? from b where id in ( ? , a , ? )"}, + {"select 1 from b order by 2", "select ? from b order by 2"}, + {"select /*+ a hint */ 1", "select ?"}, + {"select /* a hint */ 1", "select ?"}, + {"select truncate(1, 2)", "select truncate ( ... )"}, + {"select -1 + - 2 + b - c + 0.2 + (-2) from c where d in (1, -2, +3)", "select ? + ? + b - c + ? + ( ? ) from c where d in ( ... )"}, + {"select * from t where a <= -1 and b < -2 and c = -3 and c > -4 and c >= -5 and e is 1", "select * from t where a <= ? and b < ? and c = ? and c > ? and c >= ? and e is ?"}, + {"select count(a), b from t group by 2", "select count ( a ) , b from t group by 2"}, + {"select count(a), b, c from t group by 2, 3", "select count ( a ) , b , c from t group by 2 , 3"}, + {"select count(a), b, c from t group by (2, 3)", "select count ( a ) , b , c from t group by ( 2 , 3 )"}, + {"select a, b from t order by 1, 2", "select a , b from t order by 1 , 2"}, + {"select count(*) from t", "select count ( ? 
) from t"}, + {"select * from t Force Index(kk)", "select * from t"}, + {"select * from t USE Index(kk)", "select * from t"}, + {"select * from t Ignore Index(kk)", "select * from t"}, + {"select * from t1 straight_join t2 on t1.id=t2.id", "select * from t1 join t2 on t1 . id = t2 . id"}, + // test syntax error, it will be checked by parser, but it should not make normalize dead loop. + {"select * from t ignore index(", "select * from t ignore index"}, + {"select /*+ ", "select "}, + } + for _, test := range tests { + actual := parser.Normalize(test.input) + c.Assert(actual, Equals, test.expect) + } +} + +func (s *testSQLDigestSuite) TestDigestHashEqForSimpleSQL(c *C) { + sqlGroups := [][]string{ + {"select * from b where id = 1", "select * from b where id = '1'", "select * from b where id =2"}, + {"select 2 from b, c where c.id > 1", "select 4 from b, c where c.id > 23"}, + {"Select 3", "select 1"}, + } + for _, sqlGroup := range sqlGroups { + var d string + for _, sql := range sqlGroup { + dig := parser.DigestHash(sql) + if d == "" { + d = dig + continue + } + c.Assert(d, Equals, dig) + } + } +} + +func (s *testSQLDigestSuite) TestDigestHashNotEqForSimpleSQL(c *C) { + sqlGroups := [][]string{ + {"select * from b where id = 1", "select a from b where id = 1", "select * from d where bid =1"}, + } + for _, sqlGroup := range sqlGroups { + var d string + for _, sql := range sqlGroup { + dig := parser.DigestHash(sql) + if d == "" { + d = dig + continue + } + c.Assert(d, Not(Equals), dig) + } + } +} diff --git a/lexer.go b/lexer.go index 432ad7c6f..a8b3510b6 100644 --- a/lexer.go +++ b/lexer.go @@ -348,6 +348,7 @@ func startWithDash(s *Scanner) (tok int, pos Pos, lit string) { return } tok = int('-') + lit = "-" s.r.inc() return }