diff --git a/mysql/const.go b/mysql/const.go index 637118a597088..6cd649380bdc5 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -252,3 +252,6 @@ var AllTablePrivs = []PrivilegeType{SelectPriv, InsertPriv, UpdatePriv, DeletePr // AllColumnPrivs is all the privileges in column scope. var AllColumnPrivs = []PrivilegeType{SelectPriv, InsertPriv, UpdatePriv} + +// AllPrivilegeLiteral is the string literal for All Privilege. +const AllPrivilegeLiteral = "ALL PRIVILEGES" diff --git a/parser/parser.y b/parser/parser.y index ef5ed414e8644..ffb2b0ed6299c 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -152,6 +152,7 @@ import ( ge ">=" global "GLOBAL" grant "GRANT" + grants "GRANTS" group "GROUP" groupConcat "GROUP_CONCAT" having "HAVING" @@ -1700,7 +1701,7 @@ UnReservedKeyword: | "START" | "GLOBAL" | "TABLES"| "TEXT" | "TIME" | "TIMESTAMP" | "TRANSACTION" | "TRUNCATE" | "UNKNOWN" | "VALUE" | "WARNINGS" | "YEAR" | "MODE" | "WEEK" | "ANY" | "SOME" | "USER" | "IDENTIFIED" | "COLLATION" | "COMMENT" | "AVG_ROW_LENGTH" | "CONNECTION" | "CHECKSUM" | "COMPRESSION" | "KEY_BLOCK_SIZE" | "MAX_ROWS" | "MIN_ROWS" -| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" +| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" NotKeywordToken: "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" @@ -3453,6 +3454,19 @@ ShowStmt: TableIdent: $4.(table.Ident), } } +| "SHOW" "GRANTS" + { + // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html + $$ = &stmts.ShowStmt{Target: stmt.ShowGrants} + } +| "SHOW" "GRANTS" "FOR" Username + { + // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html + $$ = &stmts.ShowStmt{ + Target: stmt.ShowGrants, + User: $4.(string), + } + } ShowLikeOrWhereOpt: { diff --git a/parser/parser_test.go b/parser/parser_test.go index 65b890047c864..50cde5ef71daf 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -58,7 +58,7 @@ func (s *testParserSuite) TestSimple(c *C) { "start", "global", "tables", "text", "time", "timestamp", "transaction", "truncate", "unknown", "value", "warnings", "year", "now", "substring", "mode", "any", "some", "user", "identified", "collation", "comment", "avg_row_length", "checksum", "compression", "connection", "key_block_size", - "max_rows", "min_rows", "national", "row", "quarter", "escape", + "max_rows", "min_rows", "national", "row", "quarter", "escape", "grants", } for _, kw := range unreservedKws { src := fmt.Sprintf("SELECT %s FROM tbl;", kw) @@ -252,6 +252,8 @@ func (s *testParserSuite) TestDMLStmt(c *C) { {"SHOW GLOBAL VARIABLES WHERE Variable_name = 'autocommit'", true}, {`SHOW FULL TABLES FROM icar_qa LIKE play_evolutions`, true}, {`SHOW FULL TABLES WHERE Table_Type != 'VIEW'`, true}, + {`SHOW GRANTS`, true}, + {`SHOW GRANTS FOR 'test'@'localhost'`, true}, // For default value {"CREATE TABLE sbtest (id INTEGER UNSIGNED NOT NULL AUTO_INCREMENT, k integer UNSIGNED DEFAULT '0' NOT NULL, c char(120) DEFAULT '' NOT NULL, pad char(60) DEFAULT '' NOT NULL, PRIMARY KEY (id) )", true}, diff --git a/parser/scanner.l b/parser/scanner.l index 8ce33d9011dd3..9bccb680644d8 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -319,6 +319,7 @@ full {f}{u}{l}{l} fulltext {f}{u}{l}{l}{t}{e}{x}{t} global {g}{l}{o}{b}{a}{l} grant {g}{r}{a}{n}{t} +grants {g}{r}{a}{n}{t}{s} group {g}{r}{o}{u}{p} group_concat {g}{r}{o}{u}{p}_{c}{o}{n}{c}{a}{t} having {h}{a}{v}{i}{n}{g} @@ -692,6 +693,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} return full {fulltext} return fulltext {grant} return grant +{grants} lval.item = string(l.val) + return grants {group} return group {group_concat} lval.item = string(l.val) return groupConcat diff --git a/plan/plans/show.go b/plan/plans/show.go index 24df7b5ec5046..9b6537f9a353e 100644 --- a/plan/plans/show.go +++ b/plan/plans/show.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" + "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/stmt" @@ -53,6 +54,7 @@ type ShowPlan struct { Where expression.Expression rows []*plan.Row cursor int + User string // ShowGrants need to know username. } func (s *ShowPlan) isColOK(c *column.Col) bool { @@ -107,6 +109,8 @@ func (s *ShowPlan) GetFields() []*field.ResultField { mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong} case stmt.ShowCreateTable: names = []string{"Table", "Create Table"} + case stmt.ShowGrants: + names = []string{fmt.Sprintf("Grants for %s", s.User)} } fields := make([]*field.ResultField, 0, len(names)) for i, name := range names { @@ -164,6 +168,8 @@ func (s *ShowPlan) fetchAll(ctx context.Context) error { return s.fetchShowCollation(ctx) case stmt.ShowCreateTable: return s.fetchShowCreateTable(ctx) + case stmt.ShowGrants: + return s.fetchShowGrants(ctx) } return nil } @@ -186,7 +192,6 @@ func (s *ShowPlan) evalCondition(ctx context.Context, m map[interface{}]interfac if cond == nil { return true, nil } - return expression.EvalBoolExpr(ctx, cond, m) } @@ -498,3 +503,20 @@ func (s *ShowPlan) fetchShowCreateTable(ctx context.Context) error { return nil } + +func (s *ShowPlan) fetchShowGrants(ctx context.Context) error { + // Get checker + checker := privilege.GetPrivilegeChecker(ctx) + if checker == nil { + return errors.New("Miss privilege checker!") + } + gs, err := checker.ShowGrants(ctx, s.User) + if err != nil { + return errors.Trace(err) + } + for _, g := range gs { + data := []interface{}{g} + s.rows = append(s.rows, &plan.Row{Data: data}) + } + return nil +} diff --git a/plan/plans/show_test.go b/plan/plans/show_test.go index 46ad9468a6b96..df6c9a425a1d1 100644 --- a/plan/plans/show_test.go +++ b/plan/plans/show_test.go @@ -15,6 +15,7 @@ package plans_test import ( "database/sql" + "fmt" . "github.com/pingcap/check" "github.com/pingcap/tidb" @@ -34,17 +35,52 @@ import ( type testShowSuit struct { txn kv.Transaction ctx context.Context + + store kv.Storage + dbName string + + createDBSQL string + dropDBSQL string + useDBSQL string + createTableSQL string + createSystemDBSQL string + createUserTableSQL string + createDBPrivTableSQL string + createTablePrivTableSQL string + createColumnPrivTableSQL string } var _ = Suite(&testShowSuit{}) func (p *testShowSuit) SetUpSuite(c *C) { - var err error - store, err := tidb.NewStore(tidb.EngineGoLevelDBMemory) - c.Assert(err, IsNil) p.ctx = mock.NewContext() - p.txn, _ = store.Begin() variable.BindSessionVars(p.ctx) + + p.dbName = "testshowplan" + p.store = newStore(c, p.dbName) + p.txn, _ = p.store.Begin() + se := newSession(c, p.store, p.dbName) + p.createDBSQL = fmt.Sprintf("create database if not exists %s;", p.dbName) + p.dropDBSQL = fmt.Sprintf("drop database if exists %s;", p.dbName) + p.useDBSQL = fmt.Sprintf("use %s;", p.dbName) + p.createTableSQL = `CREATE TABLE test(id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));` + + mustExecSQL(c, se, p.createDBSQL) + mustExecSQL(c, se, p.useDBSQL) + mustExecSQL(c, se, p.createTableSQL) + + p.createSystemDBSQL = fmt.Sprintf("create database if not exists %s;", mysql.SystemDB) + p.createUserTableSQL = tidb.CreateUserTable + p.createDBPrivTableSQL = tidb.CreateDBPrivTable + p.createTablePrivTableSQL = tidb.CreateTablePrivTable + p.createColumnPrivTableSQL = tidb.CreateColumnPrivTable + + mustExecSQL(c, se, p.createSystemDBSQL) + mustExecSQL(c, se, p.createUserTableSQL) + mustExecSQL(c, se, p.createDBPrivTableSQL) + mustExecSQL(c, se, p.createTablePrivTableSQL) + mustExecSQL(c, se, p.createColumnPrivTableSQL) + } func (p *testShowSuit) TearDownSuite(c *C) { @@ -213,3 +249,42 @@ func (p *testShowSuit) TestShowTables(c *C) { rows.Next() c.Assert(rows.Err(), NotNil) } + +func (p *testShowSuit) TestShowGrants(c *C) { + se := newSession(c, p.store, p.dbName) + ctx, _ := se.(context.Context) + mustExecSQL(c, se, `CREATE USER 'test'@'localhost' identified by '123';`) + variable.GetSessionVars(ctx).User = `test@localhost` + mustExecSQL(c, se, `GRANT Index ON *.* TO 'test'@'localhost';`) + + pln := &plans.ShowPlan{ + Target: stmt.ShowGrants, + User: `test@localhost`, + } + row, err := pln.Next(ctx) + c.Assert(err, IsNil) + c.Assert(row.Data[0], Equals, `GRANT Index ON *.* TO 'test'@'localhost'`) + + fs := pln.GetFields() + c.Assert(fs, HasLen, 1) + c.Assert(fs[0].Name, Equals, `Grants for test@localhost`) +} + +func mustExecSQL(c *C, se tidb.Session, sql string) { + _, err := se.Execute(sql) + c.Assert(err, IsNil) +} + +func newStore(c *C, dbPath string) kv.Storage { + store, err := tidb.NewStore("memory" + "://" + dbPath) + c.Assert(err, IsNil) + return store +} + +func newSession(c *C, store kv.Storage, dbName string) tidb.Session { + se, err := tidb.CreateSession(store) + c.Assert(err, IsNil) + mustExecSQL(c, se, "create database if not exists "+dbName) + mustExecSQL(c, se, "use "+dbName) + return se +} diff --git a/privilege/privilege.go b/privilege/privilege.go index 9a12c143a9d4e..653baedc8ae5f 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -31,6 +31,8 @@ type Checker interface { // If tbl is nil, only check global/db scope privileges. // If tbl is not nil, check global/db/table scope privileges. Check(ctx context.Context, db *model.DBInfo, tbl *model.TableInfo, privilege mysql.PrivilegeType) (bool, error) + // Show granted privileges for user. + ShowGrants(ctx context.Context, user string) ([]string, error) } const key keyType = 0 diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 65ee50c07e27b..499d7e9ac6fcd 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/sqlexec" @@ -29,6 +30,7 @@ import ( var _ privilege.Checker = (*UserPrivileges)(nil) type privileges struct { + Level int privs map[mysql.PrivilegeType]bool } @@ -47,7 +49,72 @@ func (ps *privileges) add(p mysql.PrivilegeType) { ps.privs[p] = true } +func (ps *privileges) String() string { + switch ps.Level { + case coldef.GrantLevelGlobal: + return ps.globalPrivToString() + case coldef.GrantLevelDB: + return ps.dbPrivToString() + case coldef.GrantLevelTable: + return ps.tablePrivToString() + } + return "" +} + +func (ps *privileges) globalPrivToString() string { + if len(ps.privs) == len(mysql.AllGlobalPrivs) { + return mysql.AllPrivilegeLiteral + } + pstrs := make([]string, 0, len(ps.privs)) + // Iterate AllGlobalPrivs to get stable order result. + for _, p := range mysql.AllGlobalPrivs { + _, ok := ps.privs[p] + if !ok { + continue + } + s, _ := mysql.Priv2Str[p] + pstrs = append(pstrs, s) + } + return strings.Join(pstrs, ",") +} + +func (ps *privileges) dbPrivToString() string { + if len(ps.privs) == len(mysql.AllDBPrivs) { + return mysql.AllPrivilegeLiteral + } + pstrs := make([]string, 0, len(ps.privs)) + // Iterate AllDBPrivs to get stable order result. + for _, p := range mysql.AllDBPrivs { + _, ok := ps.privs[p] + if !ok { + continue + } + s, _ := mysql.Priv2SetStr[p] + pstrs = append(pstrs, s) + } + return strings.Join(pstrs, ",") +} + +func (ps *privileges) tablePrivToString() string { + if len(ps.privs) == len(mysql.AllTablePrivs) { + return mysql.AllPrivilegeLiteral + } + pstrs := make([]string, 0, len(ps.privs)) + // Iterate AllTablePrivs to get stable order result. + for _, p := range mysql.AllTablePrivs { + _, ok := ps.privs[p] + if !ok { + continue + } + s, _ := mysql.Priv2Str[p] + pstrs = append(pstrs, s) + } + return strings.Join(pstrs, ",") +} + type userPrivileges struct { + User string + Host string // Global privileges GlobalPrivs *privileges // DBName-privileges @@ -56,18 +123,50 @@ type userPrivileges struct { TablePrivs map[string]map[string]*privileges } +func (ps *userPrivileges) ShowGrants() []string { + gs := []string{} + // Show global grants + g := ps.GlobalPrivs.String() + if len(g) > 0 { + s := fmt.Sprintf(`GRANT %s ON *.* TO '%s'@'%s'`, g, ps.User, ps.Host) + gs = append(gs, s) + } + // Show db scope grants + for d, p := range ps.DBPrivs { + g := p.String() + if len(g) > 0 { + s := fmt.Sprintf(`GRANT %s ON %s.* TO '%s'@'%s'`, g, d, ps.User, ps.Host) + gs = append(gs, s) + } + } + // Show table scope grants + for d, dps := range ps.TablePrivs { + for t, p := range dps { + g := p.String() + if len(g) > 0 { + s := fmt.Sprintf(`GRANT %s ON %s.%s TO '%s'@'%s'`, g, d, t, ps.User, ps.Host) + gs = append(gs, s) + } + } + } + return gs +} + // UserPrivileges implements privilege.Checker interface. // This is used to check privilege for the current user. type UserPrivileges struct { - username string - host string - privs *userPrivileges + User string + privs *userPrivileges } // Check implements Checker.Check interface. func (p *UserPrivileges) Check(ctx context.Context, db *model.DBInfo, tbl *model.TableInfo, privilege mysql.PrivilegeType) (bool, error) { if p.privs == nil { // Lazy load + if len(p.User) == 0 { + // User current user + p.User = variable.GetSessionVars(ctx).User + } err := p.loadPrivileges(ctx) if err != nil { return false, errors.Trace(err) @@ -102,10 +201,15 @@ func (p *UserPrivileges) Check(ctx context.Context, db *model.DBInfo, tbl *model } func (p *UserPrivileges) loadPrivileges(ctx context.Context) error { - user := variable.GetSessionVars(ctx).User - strs := strings.Split(user, "@") - p.username, p.host = strs[0], strs[1] - p.privs = &userPrivileges{} + strs := strings.Split(p.User, "@") + if len(strs) != 2 { + return errors.Errorf("Wrong username format: %s", p.User) + } + username, host := strs[0], strs[1] + p.privs = &userPrivileges{ + User: username, + Host: host, + } // Load privileges from mysql.User/DB/Table_privs/Column_privs table err := p.loadGlobalPrivileges(ctx) if err != nil { @@ -129,13 +233,14 @@ const userTablePrivColumnStartIndex = 3 const dbTablePrivColumnStartIndex = 3 func (p *UserPrivileges) loadGlobalPrivileges(ctx context.Context) error { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, mysql.SystemDB, mysql.UserTable, p.username, p.host) + sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, + mysql.SystemDB, mysql.UserTable, p.privs.User, p.privs.Host) rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql) if err != nil { return errors.Trace(err) } defer rs.Close() - ps := &privileges{} + ps := &privileges{Level: coldef.GrantLevelGlobal} fs, err := rs.Fields() if err != nil { return errors.Trace(err) @@ -170,7 +275,8 @@ func (p *UserPrivileges) loadGlobalPrivileges(ctx context.Context) error { } func (p *UserPrivileges) loadDBScopePrivileges(ctx context.Context) error { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, mysql.SystemDB, mysql.DBTable, p.username, p.host) + sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, + mysql.SystemDB, mysql.DBTable, p.privs.User, p.privs.Host) rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql) if err != nil { return errors.Trace(err) @@ -194,7 +300,7 @@ func (p *UserPrivileges) loadDBScopePrivileges(ctx context.Context) error { if !ok { errors.New("This should be never happened!") } - ps[db] = &privileges{} + ps[db] = &privileges{Level: coldef.GrantLevelDB} for i := dbTablePrivColumnStartIndex; i < len(fs); i++ { d := row.Data[i] ed, ok := d.(mysql.Enum) @@ -217,7 +323,8 @@ func (p *UserPrivileges) loadDBScopePrivileges(ctx context.Context) error { } func (p *UserPrivileges) loadTableScopePrivileges(ctx context.Context) error { - sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, mysql.SystemDB, mysql.TablePrivTable, p.username, p.host) + sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`, + mysql.SystemDB, mysql.TablePrivTable, p.privs.User, p.privs.Host) rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql) if err != nil { return errors.Trace(err) @@ -246,7 +353,7 @@ func (p *UserPrivileges) loadTableScopePrivileges(ctx context.Context) error { if !ok { ps[db] = make(map[string]*privileges) } - ps[db][tbl] = &privileges{} + ps[db][tbl] = &privileges{Level: coldef.GrantLevelTable} // Table_priv tblPrivs, ok := row.Data[6].(mysql.Set) if !ok { @@ -264,3 +371,17 @@ func (p *UserPrivileges) loadTableScopePrivileges(ctx context.Context) error { p.privs.TablePrivs = ps return nil } + +// ShowGrants implements privilege.Checker ShowGrants interface. +func (p *UserPrivileges) ShowGrants(ctx context.Context, user string) ([]string, error) { + // If user is current user + if user == p.User { + return p.privs.ShowGrants(), nil + } + userp := &UserPrivileges{User: user} + err := userp.loadPrivileges(ctx) + if err != nil { + return nil, errors.Trace(err) + } + return userp.privs.ShowGrants(), nil +} diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index 054094a28dc29..37c74ddca5e26 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/testutil" ) func TestT(t *testing.T) { @@ -39,6 +40,7 @@ type testPrivilegeSuite struct { dbName string createDBSQL string + createDB1SQL string dropDBSQL string useDBSQL string createTableSQL string @@ -55,11 +57,13 @@ func (t *testPrivilegeSuite) SetUpTest(c *C) { t.store = newStore(c, t.dbName) se := newSession(c, t.store, t.dbName) t.createDBSQL = fmt.Sprintf("create database if not exists %s;", t.dbName) + t.createDB1SQL = fmt.Sprintf("create database if not exists %s1;", t.dbName) t.dropDBSQL = fmt.Sprintf("drop database if exists %s;", t.dbName) t.useDBSQL = fmt.Sprintf("use %s;", t.dbName) t.createTableSQL = `CREATE TABLE test(id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));` mustExec(c, se, t.createDBSQL) + mustExec(c, se, t.createDB1SQL) // create database test1 mustExec(c, se, t.useDBSQL) mustExec(c, se, t.createTableSQL) @@ -152,6 +156,83 @@ func (t *testPrivilegeSuite) TestCheckTablePrivilege(c *C) { c.Assert(r, IsTrue) } +func (t *testPrivilegeSuite) TestShowGrants(c *C) { + se := newSession(c, t.store, t.dbName) + ctx, _ := se.(context.Context) + mustExec(c, se, `CREATE USER 'show'@'localhost' identified by '123';`) + mustExec(c, se, `GRANT Index ON *.* TO 'show'@'localhost';`) + pc := &privileges.UserPrivileges{} + gs, err := pc.ShowGrants(ctx, `show@localhost`) + c.Assert(err, IsNil) + c.Assert(gs, HasLen, 1) + c.Assert(gs[0], Equals, `GRANT Index ON *.* TO 'show'@'localhost'`) + + mustExec(c, se, `GRANT Select ON *.* TO 'show'@'localhost';`) + pc = &privileges.UserPrivileges{} + gs, err = pc.ShowGrants(ctx, `show@localhost`) + c.Assert(err, IsNil) + c.Assert(gs, HasLen, 1) + c.Assert(gs[0], Equals, `GRANT Select,Index ON *.* TO 'show'@'localhost'`) + + // The order of privs is the same with AllGlobalPrivs + mustExec(c, se, `GRANT Update ON *.* TO 'show'@'localhost';`) + pc = &privileges.UserPrivileges{} + gs, err = pc.ShowGrants(ctx, `show@localhost`) + c.Assert(err, IsNil) + c.Assert(gs, HasLen, 1) + c.Assert(gs[0], Equals, `GRANT Select,Update,Index ON *.* TO 'show'@'localhost'`) + + // All privileges + mustExec(c, se, `GRANT ALL ON *.* TO 'show'@'localhost';`) + pc = &privileges.UserPrivileges{} + gs, err = pc.ShowGrants(ctx, `show@localhost`) + c.Assert(err, IsNil) + c.Assert(gs, HasLen, 1) + c.Assert(gs[0], Equals, `GRANT ALL PRIVILEGES ON *.* TO 'show'@'localhost'`) + + // Add db scope privileges + mustExec(c, se, `GRANT Select ON test.* TO 'show'@'localhost';`) + pc = &privileges.UserPrivileges{} + gs, err = pc.ShowGrants(ctx, `show@localhost`) + c.Assert(err, IsNil) + c.Assert(gs, HasLen, 2) + expected := []string{`GRANT ALL PRIVILEGES ON *.* TO 'show'@'localhost'`, + `GRANT Select ON test.* TO 'show'@'localhost'`} + c.Assert(testutil.CompareUnorderedStringSlice(gs, expected), IsTrue) + + mustExec(c, se, `GRANT Index ON test1.* TO 'show'@'localhost';`) + pc = &privileges.UserPrivileges{} + gs, err = pc.ShowGrants(ctx, `show@localhost`) + c.Assert(err, IsNil) + c.Assert(gs, HasLen, 3) + expected = []string{`GRANT ALL PRIVILEGES ON *.* TO 'show'@'localhost'`, + `GRANT Select ON test.* TO 'show'@'localhost'`, + `GRANT Index ON test1.* TO 'show'@'localhost'`} + c.Assert(testutil.CompareUnorderedStringSlice(gs, expected), IsTrue) + + mustExec(c, se, `GRANT ALL ON test1.* TO 'show'@'localhost';`) + pc = &privileges.UserPrivileges{} + gs, err = pc.ShowGrants(ctx, `show@localhost`) + c.Assert(err, IsNil) + c.Assert(gs, HasLen, 3) + expected = []string{`GRANT ALL PRIVILEGES ON *.* TO 'show'@'localhost'`, + `GRANT Select ON test.* TO 'show'@'localhost'`, + `GRANT ALL PRIVILEGES ON test1.* TO 'show'@'localhost'`} + c.Assert(testutil.CompareUnorderedStringSlice(gs, expected), IsTrue) + + // Add table scope privileges + mustExec(c, se, `GRANT Update ON test.test TO 'show'@'localhost';`) + pc = &privileges.UserPrivileges{} + gs, err = pc.ShowGrants(ctx, `show@localhost`) + c.Assert(err, IsNil) + c.Assert(gs, HasLen, 4) + expected = []string{`GRANT ALL PRIVILEGES ON *.* TO 'show'@'localhost'`, + `GRANT Select ON test.* TO 'show'@'localhost'`, + `GRANT ALL PRIVILEGES ON test1.* TO 'show'@'localhost'`, + `GRANT Update ON test.test TO 'show'@'localhost'`} + c.Assert(testutil.CompareUnorderedStringSlice(gs, expected), IsTrue) +} + func mustExec(c *C, se tidb.Session, sql string) { _, err := se.Execute(sql) c.Assert(err, IsNil) diff --git a/stmt/stmt.go b/stmt/stmt.go index f9f11fa8019b5..3a88118139b89 100644 --- a/stmt/stmt.go +++ b/stmt/stmt.go @@ -57,6 +57,7 @@ const ( ShowVariables ShowCollation ShowCreateTable + ShowGrants ) const ( diff --git a/stmt/stmts/show.go b/stmt/stmts/show.go index 8d9d160382196..83df25e5fd4d4 100644 --- a/stmt/stmts/show.go +++ b/stmt/stmts/show.go @@ -14,13 +14,13 @@ package stmts import ( - "github.com/ngaut/log" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/rset/rsets" "github.com/pingcap/tidb/sessionctx/db" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/stmt" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util/format" @@ -43,6 +43,9 @@ type ShowStmt struct { Pattern *expression.PatternLike Where expression.Expression + // Used by show grants + User string + Text string } @@ -68,8 +71,6 @@ func (s *ShowStmt) SetText(text string) { // Exec implements the stmt.Statement Exec interface. func (s *ShowStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { - // TODO: finish this - log.Debug("Exec Show Stmt") r := &plans.ShowPlan{ Target: s.Target, DBName: s.getDBName(ctx), @@ -80,6 +81,10 @@ func (s *ShowStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { GlobalScope: s.GlobalScope, Pattern: s.Pattern, Where: s.Where, + User: s.User, + } + if s.Target == stmt.ShowGrants && len(s.User) == 0 { + r.User = variable.GetSessionVars(ctx).User } return rsets.Recordset{Ctx: ctx, Plan: r}, nil } diff --git a/util/testutil/testutil.go b/util/testutil/testutil.go new file mode 100644 index 0000000000000..b9a29ecc051e2 --- /dev/null +++ b/util/testutil/testutil.go @@ -0,0 +1,50 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +// CompareUnorderedStringSlice compare two string slices. +// If a and b is exactly the same except the order, it returns true. +// In otherwise return false. +func CompareUnorderedStringSlice(a []string, b []string) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + if len(a) != len(b) { + return false + } + m := make(map[string]int, len(a)) + for _, i := range a { + _, ok := m[i] + if !ok { + m[i] = 1 + } else { + m[i]++ + } + } + + for _, i := range b { + _, ok := m[i] + if !ok { + return false + } + m[i]-- + if m[i] == 0 { + delete(m, i) + } + } + return len(m) == 0 +} diff --git a/util/testutil/testutil_test.go b/util/testutil/testutil_test.go new file mode 100644 index 0000000000000..11aea896de7b0 --- /dev/null +++ b/util/testutil/testutil_test.go @@ -0,0 +1,48 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "testing" + + . "github.com/pingcap/check" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testTestUtilSuite{}) + +type testTestUtilSuite struct { +} + +func (s *testTestUtilSuite) TestCompareUnorderedString(c *C) { + tbl := []struct { + a []string + b []string + r bool + }{ + {[]string{"1", "1", "2"}, []string{"1", "1", "2"}, true}, + {[]string{"1", "1", "2"}, []string{"1", "2", "1"}, true}, + {[]string{"1", "1"}, []string{"1", "2", "1"}, false}, + {[]string{"1", "1", "2"}, []string{"1", "2", "2"}, false}, + {nil, nil, true}, + {[]string{}, nil, false}, + {nil, []string{}, false}, + } + for _, t := range tbl { + c.Assert(CompareUnorderedStringSlice(t.a, t.b), Equals, t.r) + } +}