Skip to content

Commit

Permalink
privilege: disable role privilege when it is revoked from an user (#3…
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed Nov 24, 2021
1 parent 5916672 commit e65f548
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 4 deletions.
18 changes: 15 additions & 3 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,18 @@ type MySQLPrivilege struct {
RoleGraph map[string]roleGraphEdgesTable
}

// FindAllUserEffectiveRoles is used to find all effective roles grant to this user.
// This method will filter out the roles that are not granted to the user but are still in activeRoles
func (p *MySQLPrivilege) FindAllUserEffectiveRoles(user, host string, activeRoles []*auth.RoleIdentity) []*auth.RoleIdentity {
grantedActiveRoles := make([]*auth.RoleIdentity, 0, len(activeRoles))
for _, role := range activeRoles {
if p.FindRole(user, host, role) {
grantedActiveRoles = append(grantedActiveRoles, role)
}
}
return p.FindAllRole(grantedActiveRoles)
}

// FindAllRole is used to find all roles grant to this user.
func (p *MySQLPrivilege) FindAllRole(activeRoles []*auth.RoleIdentity) []*auth.RoleIdentity {
queue, head := make([]*auth.RoleIdentity, 0, len(activeRoles)), 0
Expand Down Expand Up @@ -968,7 +980,7 @@ func (p *MySQLPrivilege) matchColumns(user, host, db, table, column string) *col
// without accepting SUPER privilege as a fallback.
func (p *MySQLPrivilege) HasExplicitlyGrantedDynamicPrivilege(activeRoles []*auth.RoleIdentity, user, host, privName string, withGrant bool) bool {
privName = strings.ToUpper(privName)
roleList := p.FindAllRole(activeRoles)
roleList := p.FindAllUserEffectiveRoles(user, host, activeRoles)
roleList = append(roleList, &auth.RoleIdentity{Username: user, Hostname: host})
// Loop through each of the roles and return on first match
// If grantable is required, ensure the record has the GrantOption set.
Expand Down Expand Up @@ -1016,7 +1028,7 @@ func (p *MySQLPrivilege) RequestVerification(activeRoles []*auth.RoleIdentity, u
return true
}

roleList := p.FindAllRole(activeRoles)
roleList := p.FindAllUserEffectiveRoles(user, host, activeRoles)
roleList = append(roleList, &auth.RoleIdentity{Username: user, Hostname: host})

var userPriv, dbPriv, tablePriv, columnPriv mysql.PrivilegeType
Expand Down Expand Up @@ -1117,7 +1129,7 @@ func (p *MySQLPrivilege) showGrants(user, host string, roles []*auth.RoleIdentit
var hasGlobalGrant = false
// Some privileges may granted from role inheritance.
// We should find these inheritance relationship.
allRoles := p.FindAllRole(roles)
allRoles := p.FindAllUserEffectiveRoles(user, host, roles)
// Show global grants.
var currentPriv mysql.PrivilegeType
var userExists = false
Expand Down
40 changes: 40 additions & 0 deletions privilege/privileges/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,46 @@ func TestRoleGraphBFS(t *testing.T) {
require.Len(t, ret, 6)
}

func TestFindAllUserEffectiveRoles(t *testing.T) {
t.Parallel()
store, clean := newStore(t)
defer clean()

se, err := session.CreateSession4Test(store)
require.NoError(t, err)
defer se.Close()
mustExec(t, se, `CREATE USER u1`)
mustExec(t, se, `CREATE ROLE r_1, r_2, r_3, r_4;`)
mustExec(t, se, `GRANT r_3 to r_1`)
mustExec(t, se, `GRANT r_4 to r_2`)
mustExec(t, se, `GRANT r_1 to u1`)
mustExec(t, se, `GRANT r_2 to u1`)

var p privileges.MySQLPrivilege
err = p.LoadAll(se)
require.NoError(t, err)
ret := p.FindAllUserEffectiveRoles("u1", "%", []*auth.RoleIdentity{
{Username: "r_1", Hostname: "%"},
{Username: "r_2", Hostname: "%"},
})
require.Equal(t, 4, len(ret))
require.Equal(t, "r_1", ret[0].Username)
require.Equal(t, "r_2", ret[1].Username)
require.Equal(t, "r_3", ret[2].Username)
require.Equal(t, "r_4", ret[3].Username)

mustExec(t, se, `REVOKE r_2 from u1`)
err = p.LoadAll(se)
require.NoError(t, err)
ret = p.FindAllUserEffectiveRoles("u1", "%", []*auth.RoleIdentity{
{Username: "r_1", Hostname: "%"},
{Username: "r_2", Hostname: "%"},
})
require.Equal(t, 2, len(ret))
require.Equal(t, "r_1", ret[0].Username)
require.Equal(t, "r_3", ret[1].Username)
}

func TestAbnormalMySQLTable(t *testing.T) {
t.Parallel()
store, clean := newStore(t)
Expand Down
2 changes: 1 addition & 1 deletion privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ func (p *UserPrivileges) DBIsVisible(activeRoles []*auth.RoleIdentity, db string
if mysqlPriv.DBIsVisible(p.user, p.host, db) {
return true
}
allRoles := mysqlPriv.FindAllRole(activeRoles)
allRoles := mysqlPriv.FindAllUserEffectiveRoles(p.user, p.host, activeRoles)
for _, role := range allRoles {
if mysqlPriv.DBIsVisible(role.Username, role.Hostname, db) {
return true
Expand Down
31 changes: 31 additions & 0 deletions privilege/privileges/privileges_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2971,3 +2971,34 @@ func TestSkipGrantTable(t *testing.T) {
tk.MustExec(`GRANT RESTRICTED_TABLES_ADMIN ON *.* TO 'test2'@'%';`)
tk.MustExec(`GRANT RESTRICTED_USER_ADMIN ON *.* TO 'test2'@'%';`)
}

func TestIssue29823(t *testing.T) {
t.Parallel()
store, clean := newStore(t)
defer clean()

tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t1")
tk.MustExec("create user u1")
tk.MustExec("create role r1")
tk.MustExec("create table t1 (c1 int)")
tk.MustExec("grant select on t1 to r1")
tk.MustExec("grant r1 to u1")

tk2 := testkit.NewTestKit(t, store)
require.True(t, tk2.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "%"}, nil, nil))
tk2.MustExec("set role all")
tk2.MustQuery("select current_role()").Check(testkit.Rows("`r1`@`%`"))
tk2.MustQuery("select * from test.t1").Check(testkit.Rows())
tk2.MustQuery("show databases like 'test'").Check(testkit.Rows("test"))
tk2.MustQuery("show tables from test").Check(testkit.Rows("t1"))

tk.MustExec("revoke r1 from u1")
tk2.MustQuery("select current_role()").Check(testkit.Rows("`r1`@`%`"))
err := tk2.ExecToErr("select * from test.t1")
require.EqualError(t, err, "[planner:1142]SELECT command denied to user 'u1'@'%' for table 't1'")
tk2.MustQuery("show databases like 'test'").Check(testkit.Rows())
err = tk2.QueryToErr("show tables from test")
require.EqualError(t, err, "[executor:1044]Access denied for user 'u1'@'%' to database 'test'")
}

0 comments on commit e65f548

Please sign in to comment.