diff --git a/executor/errors.go b/executor/errors.go index 9b5caa85ac18b..d1fea0047ecd5 100644 --- a/executor/errors.go +++ b/executor/errors.go @@ -49,6 +49,7 @@ var ( ErrTableaccessDenied = terror.ClassExecutor.New(mysql.ErrTableaccessDenied, mysql.MySQLErrName[mysql.ErrTableaccessDenied]) ErrBadDB = terror.ClassExecutor.New(mysql.ErrBadDB, mysql.MySQLErrName[mysql.ErrBadDB]) ErrWrongObject = terror.ClassExecutor.New(mysql.ErrWrongObject, mysql.MySQLErrName[mysql.ErrWrongObject]) + ErrRoleNotGranted = terror.ClassPrivilege.New(mysql.ErrRoleNotGranted, mysql.MySQLErrName[mysql.ErrRoleNotGranted]) ) func init() { diff --git a/executor/simple.go b/executor/simple.go index 18bd342f0cc88..b23c0b8d6edcf 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -81,11 +81,34 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro return nil case *ast.DropStatsStmt: err = e.executeDropStats(x) + case *ast.SetRoleStmt: + err = e.executeSetRole(x) } e.done = true return errors.Trace(err) } +func (e *SimpleExec) executeSetRole(s *ast.SetRoleStmt) error { + checkDup := make(map[string]*auth.RoleIdentity, len(s.RoleList)) + // Check whether RoleNameList contain duplicate role name. + for _, r := range s.RoleList { + key := r.String() + checkDup[key] = r + } + roleList := make([]*auth.RoleIdentity, 0, 10) + for _, v := range checkDup { + roleList = append(roleList, v) + } + + checker := privilege.GetPrivilegeManager(e.ctx) + ok, roleName := checker.ActiveRoles(e.ctx, roleList) + if !ok { + u := e.ctx.GetSessionVars().User + return ErrRoleNotGranted.GenWithStackByArgs(roleName, u.String()) + } + return nil +} + func (e *SimpleExec) dbAccessDenied(dbname string) error { user := e.ctx.GetSessionVars().User u := user.Username diff --git a/expression/builtin.go b/expression/builtin.go index 9a90e81ebc0e1..f779687d863f0 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -470,6 +470,7 @@ var funcs = map[string]functionClass{ // information functions ast.ConnectionID: &connectionIDFunctionClass{baseFunctionClass{ast.ConnectionID, 0, 0}}, ast.CurrentUser: ¤tUserFunctionClass{baseFunctionClass{ast.CurrentUser, 0, 0}}, + ast.CurrentRole: ¤tRoleFunctionClass{baseFunctionClass{ast.CurrentRole, 0, 0}}, ast.Database: &databaseFunctionClass{baseFunctionClass{ast.Database, 0, 0}}, // This function is a synonym for DATABASE(). // See http://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_schema diff --git a/expression/builtin_info.go b/expression/builtin_info.go index a555bd08d26d5..0653afef87847 100644 --- a/expression/builtin_info.go +++ b/expression/builtin_info.go @@ -30,6 +30,7 @@ var ( _ functionClass = &databaseFunctionClass{} _ functionClass = &foundRowsFunctionClass{} _ functionClass = ¤tUserFunctionClass{} + _ functionClass = ¤tRoleFunctionClass{} _ functionClass = &userFunctionClass{} _ functionClass = &connectionIDFunctionClass{} _ functionClass = &lastInsertIDFunctionClass{} @@ -156,6 +157,50 @@ func (b *builtinCurrentUserSig) evalString(row chunk.Row) (string, bool, error) return data.User.AuthIdentityString(), false, nil } +type currentRoleFunctionClass struct { + baseFunctionClass +} + +func (c *currentRoleFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString) + bf.tp.Flen = 64 + sig := &builtinCurrentRoleSig{bf} + return sig, nil +} + +type builtinCurrentRoleSig struct { + baseBuiltinFunc +} + +func (b *builtinCurrentRoleSig) Clone() builtinFunc { + newSig := &builtinCurrentRoleSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinCurrentUserSig. +// See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user +func (b *builtinCurrentRoleSig) evalString(row chunk.Row) (string, bool, error) { + data := b.ctx.GetSessionVars() + if data == nil || data.ActiveRoles == nil { + return "", true, errors.Errorf("Missing session variable when eval builtin") + } + if len(data.ActiveRoles) == 0 { + return "", false, nil + } + res := "" + for i, r := range data.ActiveRoles { + res += r.String() + if i != len(data.ActiveRoles)-1 { + res += "," + } + } + return res, false, nil +} + type userFunctionClass struct { baseFunctionClass } diff --git a/expression/builtin_info_test.go b/expression/builtin_info_test.go index 0cc6d0f57ed71..0f82dfef2c510 100644 --- a/expression/builtin_info_test.go +++ b/expression/builtin_info_test.go @@ -95,6 +95,22 @@ func (s *testEvaluatorSuite) TestCurrentUser(c *C) { c.Assert(d.GetString(), Equals, "root@localhost") } +func (s *testEvaluatorSuite) TestCurrentRole(c *C) { + defer testleak.AfterTest(c)() + ctx := mock.NewContext() + sessionVars := ctx.GetSessionVars() + sessionVars.ActiveRoles = make([]*auth.RoleIdentity, 0, 10) + sessionVars.ActiveRoles = append(sessionVars.ActiveRoles, &auth.RoleIdentity{Username: "r_1", Hostname: "%"}) + sessionVars.ActiveRoles = append(sessionVars.ActiveRoles, &auth.RoleIdentity{Username: "r_2", Hostname: "localhost"}) + + fc := funcs[ast.CurrentRole] + f, err := fc.getFunction(ctx, nil) + c.Assert(err, IsNil) + d, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(d.GetString(), Equals, "`r_1`@`%`,`r_2`@`localhost`") +} + func (s *testEvaluatorSuite) TestConnectionID(c *C) { defer testleak.AfterTest(c)() ctx := mock.NewContext() diff --git a/expression/function_traits.go b/expression/function_traits.go index d7c0ec881f8bd..988c020b18528 100644 --- a/expression/function_traits.go +++ b/expression/function_traits.go @@ -21,6 +21,7 @@ import ( var UnCacheableFunctions = map[string]struct{}{ ast.Database: {}, ast.CurrentUser: {}, + ast.CurrentRole: {}, ast.User: {}, ast.ConnectionID: {}, ast.LastInsertId: {}, diff --git a/go.mod b/go.mod index eefe31774a0db..fcd39ba6f7103 100644 --- a/go.mod +++ b/go.mod @@ -51,7 +51,7 @@ require ( github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e github.com/pingcap/kvproto v0.0.0-20190215154024-7f2fc73ef562 github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 - github.com/pingcap/parser v0.0.0-20190312024907-3f6280b08c8b + github.com/pingcap/parser v0.0.0-20190320053247-fe243e3280cf github.com/pingcap/pd v2.1.0-rc.4+incompatible github.com/pingcap/tidb-tools v2.1.3-0.20190116051332-34c808eef588+incompatible github.com/pingcap/tipb v0.0.0-20190107072121-abbec73437b7 diff --git a/go.sum b/go.sum index d176369b6b27a..f77fa64d84138 100644 --- a/go.sum +++ b/go.sum @@ -119,8 +119,8 @@ github.com/pingcap/kvproto v0.0.0-20190215154024-7f2fc73ef562 h1:32oF1/8lVnBR2JV github.com/pingcap/kvproto v0.0.0-20190215154024-7f2fc73ef562/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY= github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 h1:t2OQTpPJnrPDGlvA+3FwJptMTt6MEPdzK1Wt99oaefQ= github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw= -github.com/pingcap/parser v0.0.0-20190312024907-3f6280b08c8b h1:NlvTrxqezIJh6CD5Leky12IZ8E/GtpEEmzgNNb34wbw= -github.com/pingcap/parser v0.0.0-20190312024907-3f6280b08c8b/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= +github.com/pingcap/parser v0.0.0-20190320053247-fe243e3280cf h1:yxK78TmeSK3BIm8Z8SwdZLVzRpY80HZe1VMlA2dL648= +github.com/pingcap/parser v0.0.0-20190320053247-fe243e3280cf/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= github.com/pingcap/pd v2.1.0-rc.4+incompatible h1:/buwGk04aHO5odk/+O8ZOXGs4qkUjYTJ2UpCJXna8NE= github.com/pingcap/pd v2.1.0-rc.4+incompatible/go.mod h1:nD3+EoYes4+aNNODO99ES59V83MZSI+dFbhyr667a0E= github.com/pingcap/tidb-tools v2.1.3-0.20190116051332-34c808eef588+incompatible h1:e9Gi/LP9181HT3gBfSOeSBA+5JfemuE4aEAhqNgoE4k= diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index f32e26239cac2..ddc35c3ea43db 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -212,8 +212,8 @@ func (b *PlanBuilder) Build(node ast.Node) (Plan, error) { case *ast.AnalyzeTableStmt: return b.buildAnalyze(x) case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt, - *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, - *ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt: + *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.GrantStmt, + *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt, *ast.SetRoleStmt: return b.buildSimple(node.(ast.StmtNode)) case ast.DDLNode: return b.buildDDL(x) diff --git a/privilege/privilege.go b/privilege/privilege.go index 7091f96effee3..b6aa136c6c806 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -52,6 +52,10 @@ type Manager interface { // UserPrivilegesTable provide data for INFORMATION_SCHEMA.USERS_PRIVILEGE table. UserPrivilegesTable() [][]types.Datum + + // ActiveRoles active roles for current session. + // The first illegal role will be returned. + ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string) } const key keyType = 0 diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index 0e5bb5bf5fb4c..2731f6f39dad1 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/auth" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx" @@ -104,12 +105,42 @@ type columnsPrivRecord struct { patTypes []byte } +// RoleGraphEdgesTable is used to cache relationship between and role. +type roleGraphEdgesTable struct { + roleList map[string]bool +} + +// Find method is used to find role from table +func (g roleGraphEdgesTable) Find(user, host string) bool { + if host == "" { + host = "%" + } + key := user + "@" + host + if g.roleList == nil { + return false + } + _, ok := g.roleList[key] + return ok +} + // MySQLPrivilege is the in-memory cache of mysql privilege tables. type MySQLPrivilege struct { User []UserRecord DB []dbRecord TablesPriv []tablesPrivRecord ColumnsPriv []columnsPrivRecord + RoleGraph map[string]roleGraphEdgesTable +} + +// FindRole is used to detect whether there is edges between users and roles. +func (p *MySQLPrivilege) FindRole(user string, host string, role *auth.RoleIdentity) bool { + rec := p.matchUser(user, host) + r := p.matchUser(role.Username, role.Hostname) + if rec != nil && r != nil { + key := rec.User + "@" + rec.Host + return p.RoleGraph[key].Find(role.Username, role.Hostname) + } + return false } // LoadAll loads the tables from database to memory. @@ -142,6 +173,14 @@ func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error { } log.Warn("mysql.columns_priv missing") } + + err = p.LoadRoleGraph(ctx) + if err != nil { + if !noSuchTable(err) { + return errors.Trace(err) + } + log.Warn("mysql.role_edges missing") + } return nil } @@ -155,6 +194,16 @@ func noSuchTable(err error) bool { return false } +// LoadRoleGraph loads the mysql.role_edges table from database. +func (p *MySQLPrivilege) LoadRoleGraph(ctx sessionctx.Context) error { + p.RoleGraph = make(map[string]roleGraphEdgesTable) + err := p.loadTable(ctx, "select FROM_USER, FROM_HOST, TO_USER, TO_HOST from mysql.role_edges;", p.decodeRoleEdgesTable) + if err != nil { + return errors.Trace(err) + } + return nil +} + // LoadUserTable loads the mysql.user table from database. func (p *MySQLPrivilege) LoadUserTable(ctx sessionctx.Context) error { err := p.loadTable(ctx, "select HIGH_PRIORITY Host,User,Password,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Process_priv,Grant_priv,References_priv,Alter_priv,Show_db_priv,Super_priv,Execute_priv,Create_view_priv,Show_view_priv,Index_priv,Create_user_priv,Trigger_priv,Create_role_priv,Drop_role_priv,account_locked from mysql.user;", p.decodeUserTableRow) @@ -381,6 +430,31 @@ func (p *MySQLPrivilege) decodeTablesPrivTableRow(row chunk.Row, fs []*ast.Resul return nil } +func (p *MySQLPrivilege) decodeRoleEdgesTable(row chunk.Row, fs []*ast.ResultField) error { + var fromUser, fromHost, toHost, toUser string + for i, f := range fs { + switch { + case f.ColumnAsName.L == "from_host": + fromHost = row.GetString(i) + case f.ColumnAsName.L == "from_user": + fromUser = row.GetString(i) + case f.ColumnAsName.L == "to_host": + toHost = row.GetString(i) + case f.ColumnAsName.L == "to_user": + toUser = row.GetString(i) + } + } + fromKey := fromUser + "@" + fromHost + toKey := toUser + "@" + toHost + roleGraph, ok := p.RoleGraph[toKey] + if !ok { + roleGraph = roleGraphEdgesTable{roleList: make(map[string]bool)} + p.RoleGraph[toKey] = roleGraph + } + roleGraph.roleList[fromKey] = true + return nil +} + func (p *MySQLPrivilege) decodeColumnsPrivTableRow(row chunk.Row, fs []*ast.ResultField) error { var value columnsPrivRecord for i, f := range fs { diff --git a/privilege/privileges/cache_test.go b/privilege/privileges/cache_test.go index ff748750f065b..032b3586ee120 100644 --- a/privilege/privileges/cache_test.go +++ b/privilege/privileges/cache_test.go @@ -185,6 +185,36 @@ func (s *testCacheSuite) TestCaseInsensitive(c *C) { c.Assert(p.RequestVerification("genius", "127.0.0.1", "tctrain", "tctrainorder", "", mysql.SelectPriv), IsTrue) } +func (s *testCacheSuite) TestLoadRoleGraph(c *C) { + se, err := session.CreateSession4Test(s.store) + c.Assert(err, IsNil) + defer se.Close() + mustExec(c, se, "use mysql;") + mustExec(c, se, "truncate table user;") + + var p privileges.MySQLPrivilege + err = p.LoadRoleGraph(se) + c.Assert(err, IsNil) + c.Assert(len(p.User), Equals, 0) + + mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_1", "%", "user2")`) + mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_2", "%", "root")`) + mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_3", "%", "user1")`) + mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_4", "%", "root")`) + + p = privileges.MySQLPrivilege{} + err = p.LoadRoleGraph(se) + c.Assert(err, IsNil) + graph := p.RoleGraph + c.Assert(graph["root@%"].Find("r_2", "%"), Equals, true) + c.Assert(graph["root@%"].Find("r_4", "%"), Equals, true) + c.Assert(graph["user2@%"].Find("r_1", "%"), Equals, true) + c.Assert(graph["user1@%"].Find("r_3", "%"), Equals, true) + _, ok := graph["illedal"] + c.Assert(ok, Equals, false) + c.Assert(graph["root@%"].Find("r_1", "%"), Equals, false) +} + func (s *testCacheSuite) TestAbnormalMySQLTable(c *C) { store, err := mockstore.NewMockTikvStore() c.Assert(err, IsNil) diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 807359ee79594..d5ba12d757c9b 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -186,3 +186,19 @@ func (p *UserPrivileges) ShowGrants(ctx sessionctx.Context, user *auth.UserIdent return } + +// ActiveRoles implements privilege.Manager ActiveRoles interface. +func (p *UserPrivileges) ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string) { + mysqlPrivilege := p.Handle.Get() + u := p.user + h := p.host + for _, r := range roleList { + ok := mysqlPrivilege.FindRole(u, h, r) + if !ok { + log.Errorf("Role: %+v doesn't grant for user", r) + return false, r.String() + } + } + ctx.GetSessionVars().ActiveRoles = roleList + return true, "" +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 0fa3773e40d4f..bbffa19c577d6 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -191,6 +191,9 @@ type SessionVars struct { // params for prepared statements PreparedParams []types.Datum + // ActiveRoles stores active roles for current user + ActiveRoles []*auth.RoleIdentity + // retry information RetryInfo *RetryInfo // Should be reset on transaction finished. @@ -350,6 +353,7 @@ func NewSessionVars() *SessionVars { TxnCtx: &TransactionContext{}, KVVars: kv.NewVariables(), RetryInfo: &RetryInfo{}, + ActiveRoles: make([]*auth.RoleIdentity, 0, 10), StrictSQLMode: true, Status: mysql.ServerStatusAutocommit, StmtCtx: new(stmtctx.StatementContext),