Skip to content

Commit

Permalink
privilege, executor: add SET ROLE and CURRENT_ROLE support (#9581)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lingyu Song authored Mar 21, 2019
1 parent e829920 commit 778c3f4
Show file tree
Hide file tree
Showing 14 changed files with 220 additions and 5 deletions.
1 change: 1 addition & 0 deletions executor/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ var (
ErrTableaccessDenied = terror.ClassExecutor.New(mysql.ErrTableaccessDenied, mysql.MySQLErrName[mysql.ErrTableaccessDenied])
ErrBadDB = terror.ClassExecutor.New(mysql.ErrBadDB, mysql.MySQLErrName[mysql.ErrBadDB])
ErrWrongObject = terror.ClassExecutor.New(mysql.ErrWrongObject, mysql.MySQLErrName[mysql.ErrWrongObject])
ErrRoleNotGranted = terror.ClassPrivilege.New(mysql.ErrRoleNotGranted, mysql.MySQLErrName[mysql.ErrRoleNotGranted])
)

func init() {
Expand Down
23 changes: 23 additions & 0 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,34 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro
return nil
case *ast.DropStatsStmt:
err = e.executeDropStats(x)
case *ast.SetRoleStmt:
err = e.executeSetRole(x)
}
e.done = true
return errors.Trace(err)
}

func (e *SimpleExec) executeSetRole(s *ast.SetRoleStmt) error {
checkDup := make(map[string]*auth.RoleIdentity, len(s.RoleList))
// Check whether RoleNameList contain duplicate role name.
for _, r := range s.RoleList {
key := r.String()
checkDup[key] = r
}
roleList := make([]*auth.RoleIdentity, 0, 10)
for _, v := range checkDup {
roleList = append(roleList, v)
}

checker := privilege.GetPrivilegeManager(e.ctx)
ok, roleName := checker.ActiveRoles(e.ctx, roleList)
if !ok {
u := e.ctx.GetSessionVars().User
return ErrRoleNotGranted.GenWithStackByArgs(roleName, u.String())
}
return nil
}

func (e *SimpleExec) dbAccessDenied(dbname string) error {
user := e.ctx.GetSessionVars().User
u := user.Username
Expand Down
1 change: 1 addition & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ var funcs = map[string]functionClass{
// information functions
ast.ConnectionID: &connectionIDFunctionClass{baseFunctionClass{ast.ConnectionID, 0, 0}},
ast.CurrentUser: &currentUserFunctionClass{baseFunctionClass{ast.CurrentUser, 0, 0}},
ast.CurrentRole: &currentRoleFunctionClass{baseFunctionClass{ast.CurrentRole, 0, 0}},
ast.Database: &databaseFunctionClass{baseFunctionClass{ast.Database, 0, 0}},
// This function is a synonym for DATABASE().
// See http://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_schema
Expand Down
45 changes: 45 additions & 0 deletions expression/builtin_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var (
_ functionClass = &databaseFunctionClass{}
_ functionClass = &foundRowsFunctionClass{}
_ functionClass = &currentUserFunctionClass{}
_ functionClass = &currentRoleFunctionClass{}
_ functionClass = &userFunctionClass{}
_ functionClass = &connectionIDFunctionClass{}
_ functionClass = &lastInsertIDFunctionClass{}
Expand Down Expand Up @@ -156,6 +157,50 @@ func (b *builtinCurrentUserSig) evalString(row chunk.Row) (string, bool, error)
return data.User.AuthIdentityString(), false, nil
}

type currentRoleFunctionClass struct {
baseFunctionClass
}

func (c *currentRoleFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString)
bf.tp.Flen = 64
sig := &builtinCurrentRoleSig{bf}
return sig, nil
}

type builtinCurrentRoleSig struct {
baseBuiltinFunc
}

func (b *builtinCurrentRoleSig) Clone() builtinFunc {
newSig := &builtinCurrentRoleSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalString evals a builtinCurrentUserSig.
// See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user
func (b *builtinCurrentRoleSig) evalString(row chunk.Row) (string, bool, error) {
data := b.ctx.GetSessionVars()
if data == nil || data.ActiveRoles == nil {
return "", true, errors.Errorf("Missing session variable when eval builtin")
}
if len(data.ActiveRoles) == 0 {
return "", false, nil
}
res := ""
for i, r := range data.ActiveRoles {
res += r.String()
if i != len(data.ActiveRoles)-1 {
res += ","
}
}
return res, false, nil
}

type userFunctionClass struct {
baseFunctionClass
}
Expand Down
16 changes: 16 additions & 0 deletions expression/builtin_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,22 @@ func (s *testEvaluatorSuite) TestCurrentUser(c *C) {
c.Assert(d.GetString(), Equals, "root@localhost")
}

func (s *testEvaluatorSuite) TestCurrentRole(c *C) {
defer testleak.AfterTest(c)()
ctx := mock.NewContext()
sessionVars := ctx.GetSessionVars()
sessionVars.ActiveRoles = make([]*auth.RoleIdentity, 0, 10)
sessionVars.ActiveRoles = append(sessionVars.ActiveRoles, &auth.RoleIdentity{Username: "r_1", Hostname: "%"})
sessionVars.ActiveRoles = append(sessionVars.ActiveRoles, &auth.RoleIdentity{Username: "r_2", Hostname: "localhost"})

fc := funcs[ast.CurrentRole]
f, err := fc.getFunction(ctx, nil)
c.Assert(err, IsNil)
d, err := evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, IsNil)
c.Assert(d.GetString(), Equals, "`r_1`@`%`,`r_2`@`localhost`")
}

func (s *testEvaluatorSuite) TestConnectionID(c *C) {
defer testleak.AfterTest(c)()
ctx := mock.NewContext()
Expand Down
1 change: 1 addition & 0 deletions expression/function_traits.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
var UnCacheableFunctions = map[string]struct{}{
ast.Database: {},
ast.CurrentUser: {},
ast.CurrentRole: {},
ast.User: {},
ast.ConnectionID: {},
ast.LastInsertId: {},
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ require (
github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e
github.com/pingcap/kvproto v0.0.0-20190215154024-7f2fc73ef562
github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596
github.com/pingcap/parser v0.0.0-20190312024907-3f6280b08c8b
github.com/pingcap/parser v0.0.0-20190320053247-fe243e3280cf
github.com/pingcap/pd v2.1.0-rc.4+incompatible
github.com/pingcap/tidb-tools v2.1.3-0.20190116051332-34c808eef588+incompatible
github.com/pingcap/tipb v0.0.0-20190107072121-abbec73437b7
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ github.com/pingcap/kvproto v0.0.0-20190215154024-7f2fc73ef562 h1:32oF1/8lVnBR2JV
github.com/pingcap/kvproto v0.0.0-20190215154024-7f2fc73ef562/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY=
github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 h1:t2OQTpPJnrPDGlvA+3FwJptMTt6MEPdzK1Wt99oaefQ=
github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw=
github.com/pingcap/parser v0.0.0-20190312024907-3f6280b08c8b h1:NlvTrxqezIJh6CD5Leky12IZ8E/GtpEEmzgNNb34wbw=
github.com/pingcap/parser v0.0.0-20190312024907-3f6280b08c8b/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/parser v0.0.0-20190320053247-fe243e3280cf h1:yxK78TmeSK3BIm8Z8SwdZLVzRpY80HZe1VMlA2dL648=
github.com/pingcap/parser v0.0.0-20190320053247-fe243e3280cf/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/pd v2.1.0-rc.4+incompatible h1:/buwGk04aHO5odk/+O8ZOXGs4qkUjYTJ2UpCJXna8NE=
github.com/pingcap/pd v2.1.0-rc.4+incompatible/go.mod h1:nD3+EoYes4+aNNODO99ES59V83MZSI+dFbhyr667a0E=
github.com/pingcap/tidb-tools v2.1.3-0.20190116051332-34c808eef588+incompatible h1:e9Gi/LP9181HT3gBfSOeSBA+5JfemuE4aEAhqNgoE4k=
Expand Down
4 changes: 2 additions & 2 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ func (b *PlanBuilder) Build(node ast.Node) (Plan, error) {
case *ast.AnalyzeTableStmt:
return b.buildAnalyze(x)
case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt,
*ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt:
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.GrantStmt,
*ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt, *ast.SetRoleStmt:
return b.buildSimple(node.(ast.StmtNode))
case ast.DDLNode:
return b.buildDDL(x)
Expand Down
4 changes: 4 additions & 0 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ type Manager interface {

// UserPrivilegesTable provide data for INFORMATION_SCHEMA.USERS_PRIVILEGE table.
UserPrivilegesTable() [][]types.Datum

// ActiveRoles active roles for current session.
// The first illegal role will be returned.
ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string)
}

const key keyType = 0
Expand Down
74 changes: 74 additions & 0 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/sessionctx"
Expand Down Expand Up @@ -104,12 +105,42 @@ type columnsPrivRecord struct {
patTypes []byte
}

// RoleGraphEdgesTable is used to cache relationship between and role.
type roleGraphEdgesTable struct {
roleList map[string]bool
}

// Find method is used to find role from table
func (g roleGraphEdgesTable) Find(user, host string) bool {
if host == "" {
host = "%"
}
key := user + "@" + host
if g.roleList == nil {
return false
}
_, ok := g.roleList[key]
return ok
}

// MySQLPrivilege is the in-memory cache of mysql privilege tables.
type MySQLPrivilege struct {
User []UserRecord
DB []dbRecord
TablesPriv []tablesPrivRecord
ColumnsPriv []columnsPrivRecord
RoleGraph map[string]roleGraphEdgesTable
}

// FindRole is used to detect whether there is edges between users and roles.
func (p *MySQLPrivilege) FindRole(user string, host string, role *auth.RoleIdentity) bool {
rec := p.matchUser(user, host)
r := p.matchUser(role.Username, role.Hostname)
if rec != nil && r != nil {
key := rec.User + "@" + rec.Host
return p.RoleGraph[key].Find(role.Username, role.Hostname)
}
return false
}

// LoadAll loads the tables from database to memory.
Expand Down Expand Up @@ -142,6 +173,14 @@ func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error {
}
log.Warn("mysql.columns_priv missing")
}

err = p.LoadRoleGraph(ctx)
if err != nil {
if !noSuchTable(err) {
return errors.Trace(err)
}
log.Warn("mysql.role_edges missing")
}
return nil
}

Expand All @@ -155,6 +194,16 @@ func noSuchTable(err error) bool {
return false
}

// LoadRoleGraph loads the mysql.role_edges table from database.
func (p *MySQLPrivilege) LoadRoleGraph(ctx sessionctx.Context) error {
p.RoleGraph = make(map[string]roleGraphEdgesTable)
err := p.loadTable(ctx, "select FROM_USER, FROM_HOST, TO_USER, TO_HOST from mysql.role_edges;", p.decodeRoleEdgesTable)
if err != nil {
return errors.Trace(err)
}
return nil
}

// LoadUserTable loads the mysql.user table from database.
func (p *MySQLPrivilege) LoadUserTable(ctx sessionctx.Context) error {
err := p.loadTable(ctx, "select HIGH_PRIORITY Host,User,Password,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Process_priv,Grant_priv,References_priv,Alter_priv,Show_db_priv,Super_priv,Execute_priv,Create_view_priv,Show_view_priv,Index_priv,Create_user_priv,Trigger_priv,Create_role_priv,Drop_role_priv,account_locked from mysql.user;", p.decodeUserTableRow)
Expand Down Expand Up @@ -381,6 +430,31 @@ func (p *MySQLPrivilege) decodeTablesPrivTableRow(row chunk.Row, fs []*ast.Resul
return nil
}

func (p *MySQLPrivilege) decodeRoleEdgesTable(row chunk.Row, fs []*ast.ResultField) error {
var fromUser, fromHost, toHost, toUser string
for i, f := range fs {
switch {
case f.ColumnAsName.L == "from_host":
fromHost = row.GetString(i)
case f.ColumnAsName.L == "from_user":
fromUser = row.GetString(i)
case f.ColumnAsName.L == "to_host":
toHost = row.GetString(i)
case f.ColumnAsName.L == "to_user":
toUser = row.GetString(i)
}
}
fromKey := fromUser + "@" + fromHost
toKey := toUser + "@" + toHost
roleGraph, ok := p.RoleGraph[toKey]
if !ok {
roleGraph = roleGraphEdgesTable{roleList: make(map[string]bool)}
p.RoleGraph[toKey] = roleGraph
}
roleGraph.roleList[fromKey] = true
return nil
}

func (p *MySQLPrivilege) decodeColumnsPrivTableRow(row chunk.Row, fs []*ast.ResultField) error {
var value columnsPrivRecord
for i, f := range fs {
Expand Down
30 changes: 30 additions & 0 deletions privilege/privileges/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,36 @@ func (s *testCacheSuite) TestCaseInsensitive(c *C) {
c.Assert(p.RequestVerification("genius", "127.0.0.1", "tctrain", "tctrainorder", "", mysql.SelectPriv), IsTrue)
}

func (s *testCacheSuite) TestLoadRoleGraph(c *C) {
se, err := session.CreateSession4Test(s.store)
c.Assert(err, IsNil)
defer se.Close()
mustExec(c, se, "use mysql;")
mustExec(c, se, "truncate table user;")

var p privileges.MySQLPrivilege
err = p.LoadRoleGraph(se)
c.Assert(err, IsNil)
c.Assert(len(p.User), Equals, 0)

mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_1", "%", "user2")`)
mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_2", "%", "root")`)
mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_3", "%", "user1")`)
mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_4", "%", "root")`)

p = privileges.MySQLPrivilege{}
err = p.LoadRoleGraph(se)
c.Assert(err, IsNil)
graph := p.RoleGraph
c.Assert(graph["root@%"].Find("r_2", "%"), Equals, true)
c.Assert(graph["root@%"].Find("r_4", "%"), Equals, true)
c.Assert(graph["user2@%"].Find("r_1", "%"), Equals, true)
c.Assert(graph["user1@%"].Find("r_3", "%"), Equals, true)
_, ok := graph["illedal"]
c.Assert(ok, Equals, false)
c.Assert(graph["root@%"].Find("r_1", "%"), Equals, false)
}

func (s *testCacheSuite) TestAbnormalMySQLTable(c *C) {
store, err := mockstore.NewMockTikvStore()
c.Assert(err, IsNil)
Expand Down
16 changes: 16 additions & 0 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,19 @@ func (p *UserPrivileges) ShowGrants(ctx sessionctx.Context, user *auth.UserIdent

return
}

// ActiveRoles implements privilege.Manager ActiveRoles interface.
func (p *UserPrivileges) ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string) {
mysqlPrivilege := p.Handle.Get()
u := p.user
h := p.host
for _, r := range roleList {
ok := mysqlPrivilege.FindRole(u, h, r)
if !ok {
log.Errorf("Role: %+v doesn't grant for user", r)
return false, r.String()
}
}
ctx.GetSessionVars().ActiveRoles = roleList
return true, ""
}
4 changes: 4 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ type SessionVars struct {
// params for prepared statements
PreparedParams []types.Datum

// ActiveRoles stores active roles for current user
ActiveRoles []*auth.RoleIdentity

// retry information
RetryInfo *RetryInfo
// Should be reset on transaction finished.
Expand Down Expand Up @@ -350,6 +353,7 @@ func NewSessionVars() *SessionVars {
TxnCtx: &TransactionContext{},
KVVars: kv.NewVariables(),
RetryInfo: &RetryInfo{},
ActiveRoles: make([]*auth.RoleIdentity, 0, 10),
StrictSQLMode: true,
Status: mysql.ServerStatusAutocommit,
StmtCtx: new(stmtctx.StatementContext),
Expand Down

0 comments on commit 778c3f4

Please sign in to comment.