Skip to content

Commit

Permalink
executor, privileges: fix privileges check fail for SET DEFAULT ROLE(
Browse files Browse the repository at this point in the history
  • Loading branch information
Lingyu Song authored and sre-bot committed Aug 21, 2019
1 parent 7f15d28 commit 48a3574
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 1 deletion.
12 changes: 12 additions & 0 deletions executor/show_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ func (s *testSuite2) TestIssue10549(c *C) {
tk.MustQuery("SHOW GRANTS FOR CURRENT_USER").Check(testkit.Rows("GRANT USAGE ON *.* TO 'dev'@'%'", "GRANT 'app_developer'@'%' TO 'dev'@'%'"))
}

func (s *testSuite3) TestIssue11165(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("CREATE ROLE 'r_manager';")
tk.MustExec("CREATE USER 'manager'@'localhost';")
tk.MustExec("GRANT 'r_manager' TO 'manager'@'localhost';")

c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "manager", Hostname: "localhost", AuthUsername: "manager", AuthHostname: "localhost"}, nil, nil), IsTrue)
tk.MustExec("SET DEFAULT ROLE ALL TO 'manager'@'localhost';")
tk.MustExec("SET DEFAULT ROLE NONE TO 'manager'@'localhost';")
tk.MustExec("SET DEFAULT ROLE 'r_manager' TO 'manager'@'localhost';")
}

// TestShow2 is moved from session_test
func (s *testSuite2) TestShow2(c *C) {
tk := testkit.NewTestKit(c, s.store)
Expand Down
102 changes: 102 additions & 0 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"
"strings"

"github.com/ngaut/pools"
"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/auth"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/plugin"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
Expand All @@ -51,6 +53,24 @@ type SimpleExec struct {
is infoschema.InfoSchema
}

func (e *SimpleExec) getSysSession() (sessionctx.Context, error) {
dom := domain.GetDomain(e.ctx)
sysSessionPool := dom.SysSessionPool()
ctx, err := sysSessionPool.Get()
if err != nil {
return nil, err
}
restrictedCtx := ctx.(sessionctx.Context)
restrictedCtx.GetSessionVars().InRestrictedSQL = true
return restrictedCtx, nil
}

func (e *SimpleExec) releaseSysSession(ctx sessionctx.Context) {
dom := domain.GetDomain(e.ctx)
sysSessionPool := dom.SysSessionPool()
sysSessionPool.Put(ctx.(pools.Resource))
}

// Next implements the Executor Next interface.
func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
if e.done {
Expand Down Expand Up @@ -230,7 +250,89 @@ func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error {
return nil
}

func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (err error) {
checker := privilege.GetPrivilegeManager(e.ctx)
user, sql := s.UserList[0], ""
if user.Hostname == "" {
user.Hostname = "%"
}
switch s.SetRoleOpt {
case ast.SetRoleNone:
sql = fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname)
case ast.SetRoleAll:
sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+
"SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username)
case ast.SetRoleRegular:
sql = "INSERT IGNORE INTO mysql.default_roles values"
for i, role := range s.RoleList {
ok := checker.FindEdge(e.ctx, role, user)
if !ok {
return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String())
}
sql += fmt.Sprintf("('%s', '%s', '%s', '%s')", user.Hostname, user.Username, role.Hostname, role.Username)
if i != len(s.RoleList)-1 {
sql += ","
}
}
}

restrictedCtx, err := e.getSysSession()
if err != nil {
return err
}
defer e.releaseSysSession(restrictedCtx)

sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor)

if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil {
return err
}

deleteSQL := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname)
if _, err := sqlExecutor.Execute(context.Background(), deleteSQL); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}

if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil {
return err
}
return nil
}

func (e *SimpleExec) executeSetDefaultRole(s *ast.SetDefaultRoleStmt) (err error) {
sessionVars := e.ctx.GetSessionVars()
checker := privilege.GetPrivilegeManager(e.ctx)
if checker == nil {
return errors.New("miss privilege checker")
}

if len(s.UserList) == 1 && sessionVars.User != nil {
u, h := s.UserList[0].Username, s.UserList[0].Hostname
if u == sessionVars.User.Username && h == sessionVars.User.AuthHostname {
err = e.setDefaultRoleForCurrentUser(s)
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return
}
}

activeRoles := sessionVars.ActiveRoles
if !checker.RequestVerification(activeRoles, mysql.SystemDB, mysql.DefaultRoleTable, "", mysql.UpdatePriv) {
if !checker.RequestVerification(activeRoles, "", "", "", mysql.CreateUserPriv) {
return core.ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
}
}

switch s.SetRoleOpt {
case ast.SetRoleAll:
err = e.setDefaultRoleAll(s)
Expand Down
2 changes: 1 addition & 1 deletion planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,7 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) {
err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err)
}
case *ast.AlterUserStmt, *ast.SetDefaultRoleStmt:
case *ast.AlterUserStmt:
err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err)
case *ast.GrantStmt:
Expand Down

0 comments on commit 48a3574

Please sign in to comment.