From f3a4a53a55a5c8d0221ae94f3c28cb9d7f97ed7f Mon Sep 17 00:00:00 2001 From: imtbkcat Date: Tue, 16 Jul 2019 13:36:10 +0800 Subject: [PATCH 1/3] fix serveral bug of rbac --- executor/simple.go | 19 +++++++++++++++++-- expression/builtin_info.go | 11 +++++++++-- privilege/privileges/cache.go | 8 +++++++- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/executor/simple.go b/executor/simple.go index be97225d5629a..4c0af96d0874b 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -471,13 +471,20 @@ func (e *SimpleExec) executeRevokeRole(s *ast.RevokeRoleStmt) error { } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", role.String()) } + sql = fmt.Sprintf(`DELETE IGNORE FROM %s.%s WHERE DEFAULT_ROLE_HOST='%s' and DEFAULT_ROLE_USER='%s' and HOST='%s' and USER='%s'`, mysql.SystemDB, mysql.DefaultRoleTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + return errors.Trace(err) + } + return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", role.String()) + } } } if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "commit"); err != nil { return err } - err := domain.GetDomain(e.ctx).PrivilegeHandle().Update(e.ctx.(sessionctx.Context)) - return errors.Trace(err) + domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) + return nil } func (e *SimpleExec) executeCommit(s *ast.CommitStmt) { @@ -597,6 +604,14 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error { failedUsers := make([]string, 0, len(s.Users)) + sessionVars := e.ctx.GetSessionVars() + for i, user := range s.Users { + if user.CurrentUser { + s.Users[i].Username = sessionVars.User.AuthUsername + s.Users[i].Hostname = sessionVars.User.AuthHostname + } + } + for _, role := range s.Roles { exists, err := userExists(e.ctx, role.Username, role.Hostname) if err != nil { diff --git a/expression/builtin_info.go b/expression/builtin_info.go index 0653afef87847..2e6ab91c6b5c5 100644 --- a/expression/builtin_info.go +++ b/expression/builtin_info.go @@ -18,6 +18,8 @@ package expression import ( + "sort" + "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" @@ -192,8 +194,13 @@ func (b *builtinCurrentRoleSig) evalString(row chunk.Row) (string, bool, error) return "", false, nil } res := "" - for i, r := range data.ActiveRoles { - res += r.String() + sortedRes := make([]string, 0, 10) + for _, r := range data.ActiveRoles { + sortedRes = append(sortedRes, r.String()) + } + sort.Strings(sortedRes) + for i, r := range sortedRes { + res += r if i != len(data.ActiveRoles)-1 { res += "," } diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index d574a8759dfbc..6fd962923039c 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -848,13 +848,19 @@ func (p *MySQLPrivilege) showGrants(user, host string, roles []*auth.RoleIdentit edgeTable, ok := p.RoleGraph[graphKey] g = "" if ok { + sortedRes := make([]string, 0, 10) for k := range edgeTable.roleList { role := strings.Split(k, "@") roleName, roleHost := role[0], role[1] + tmp := fmt.Sprintf("'%s'@'%s'", roleName, roleHost) + sortedRes = append(sortedRes, tmp) + } + sort.Strings(sortedRes) + for _, r := range sortedRes { + g += r if g != "" { g += ", " } - g += fmt.Sprintf("'%s'@'%s'", roleName, roleHost) } s := fmt.Sprintf(`GRANT %s TO '%s'@'%s'`, g, user, host) gs = append(gs, s) From 05a881cb3a31f353c13b8b3869928d39d858aae8 Mon Sep 17 00:00:00 2001 From: imtbkcat Date: Tue, 16 Jul 2019 14:37:18 +0800 Subject: [PATCH 2/3] add test --- executor/simple_test.go | 15 +++++++++++++++ privilege/privileges/cache.go | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/executor/simple_test.go b/executor/simple_test.go index 4a0499f9e9b73..ef71db312a840 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -136,6 +136,15 @@ func (s *testSuite3) TestRole(c *C) { grantRoleSQL = `GRANT 'r_1'@'localhost' TO 'r_3'@'localhost', 'r_4'@'localhost';` _, err = tk.Exec(grantRoleSQL) c.Check(err, NotNil) + + // Test grant role for current_user(); + sessionVars := tk.Se.GetSessionVars() + originUser := sessionVars.User + sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost", AuthUsername: "root", AuthHostname: "%"} + tk.MustExec("grant 'r_1'@'localhost' to current_user();") + tk.MustExec("revoke 'r_1'@'localhost' from 'root'@'%';") + sessionVars.User = originUser + result = tk.MustQuery(`SELECT FROM_USER FROM mysql.role_edges WHERE TO_USER="r_3" and TO_HOST="localhost"`) result.Check(nil) @@ -152,14 +161,20 @@ func (s *testSuite3) TestRole(c *C) { tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('localhost','test','%','root')") tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('%','r_1','%','root')") tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('%','r_2','%','root')") + tk.MustExec("flush privileges") + tk.MustExec("SET DEFAULT ROLE r_1, r_2 TO root") _, err = tk.Exec("revoke test@localhost, r_1 from root;") c.Check(err, IsNil) _, err = tk.Exec("revoke `r_2`@`%` from root, u_2;") c.Check(err, NotNil) _, err = tk.Exec("revoke `r_2`@`%` from root;") c.Check(err, IsNil) + _, err = tk.Exec("revoke `r_1`@`%` from root;") + c.Check(err, IsNil) result = tk.MustQuery(`SELECT * FROM mysql.default_roles WHERE DEFAULT_ROLE_USER="test" and DEFAULT_ROLE_HOST="localhost"`) result.Check(nil) + result = tk.MustQuery(`SELECT * FROM mysql.default_roles WHERE USER="root" and HOST="%"`) + result.Check(nil) dropRoleSQL = `DROP ROLE 'test'@'localhost', r_1, r_2;` tk.MustExec(dropRoleSQL) } diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index 6fd962923039c..daf8820039da3 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -856,9 +856,9 @@ func (p *MySQLPrivilege) showGrants(user, host string, roles []*auth.RoleIdentit sortedRes = append(sortedRes, tmp) } sort.Strings(sortedRes) - for _, r := range sortedRes { + for i, r := range sortedRes { g += r - if g != "" { + if i != len(sortedRes) - 1 { g += ", " } } From c4bcd6b1ac220f0bfbbc318bbe07da1337caadeb Mon Sep 17 00:00:00 2001 From: imtbkcat Date: Tue, 16 Jul 2019 15:19:20 +0800 Subject: [PATCH 3/3] gofmt project --- privilege/privileges/cache.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index daf8820039da3..699bfe616dacc 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -858,7 +858,7 @@ func (p *MySQLPrivilege) showGrants(user, host string, roles []*auth.RoleIdentit sort.Strings(sortedRes) for i, r := range sortedRes { g += r - if i != len(sortedRes) - 1 { + if i != len(sortedRes)-1 { g += ", " } }