diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index 150a47d896403..c306443457caa 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -268,6 +268,18 @@ type MySQLPrivilege struct { RoleGraph map[string]roleGraphEdgesTable } +// FindAllUserEffectiveRoles is used to find all effective roles grant to this user. +// This method will filter out the roles that are not granted to the user but are still in activeRoles +func (p *MySQLPrivilege) FindAllUserEffectiveRoles(user, host string, activeRoles []*auth.RoleIdentity) []*auth.RoleIdentity { + grantedActiveRoles := make([]*auth.RoleIdentity, 0, len(activeRoles)) + for _, role := range activeRoles { + if p.FindRole(user, host, role) { + grantedActiveRoles = append(grantedActiveRoles, role) + } + } + return p.FindAllRole(grantedActiveRoles) +} + // FindAllRole is used to find all roles grant to this user. func (p *MySQLPrivilege) FindAllRole(activeRoles []*auth.RoleIdentity) []*auth.RoleIdentity { queue, head := make([]*auth.RoleIdentity, 0, len(activeRoles)), 0 @@ -968,7 +980,7 @@ func (p *MySQLPrivilege) matchColumns(user, host, db, table, column string) *col // without accepting SUPER privilege as a fallback. func (p *MySQLPrivilege) HasExplicitlyGrantedDynamicPrivilege(activeRoles []*auth.RoleIdentity, user, host, privName string, withGrant bool) bool { privName = strings.ToUpper(privName) - roleList := p.FindAllRole(activeRoles) + roleList := p.FindAllUserEffectiveRoles(user, host, activeRoles) roleList = append(roleList, &auth.RoleIdentity{Username: user, Hostname: host}) // Loop through each of the roles and return on first match // If grantable is required, ensure the record has the GrantOption set. @@ -1016,7 +1028,7 @@ func (p *MySQLPrivilege) RequestVerification(activeRoles []*auth.RoleIdentity, u return true } - roleList := p.FindAllRole(activeRoles) + roleList := p.FindAllUserEffectiveRoles(user, host, activeRoles) roleList = append(roleList, &auth.RoleIdentity{Username: user, Hostname: host}) var userPriv, dbPriv, tablePriv, columnPriv mysql.PrivilegeType @@ -1117,7 +1129,7 @@ func (p *MySQLPrivilege) showGrants(user, host string, roles []*auth.RoleIdentit var hasGlobalGrant = false // Some privileges may granted from role inheritance. // We should find these inheritance relationship. - allRoles := p.FindAllRole(roles) + allRoles := p.FindAllUserEffectiveRoles(user, host, roles) // Show global grants. var currentPriv mysql.PrivilegeType var userExists = false diff --git a/privilege/privileges/cache_test.go b/privilege/privileges/cache_test.go index 307ac7defc315..a975d3f4dadac 100644 --- a/privilege/privileges/cache_test.go +++ b/privilege/privileges/cache_test.go @@ -381,6 +381,46 @@ func TestRoleGraphBFS(t *testing.T) { require.Len(t, ret, 6) } +func TestFindAllUserEffectiveRoles(t *testing.T) { + t.Parallel() + store, clean := newStore(t) + defer clean() + + se, err := session.CreateSession4Test(store) + require.NoError(t, err) + defer se.Close() + mustExec(t, se, `CREATE USER u1`) + mustExec(t, se, `CREATE ROLE r_1, r_2, r_3, r_4;`) + mustExec(t, se, `GRANT r_3 to r_1`) + mustExec(t, se, `GRANT r_4 to r_2`) + mustExec(t, se, `GRANT r_1 to u1`) + mustExec(t, se, `GRANT r_2 to u1`) + + var p privileges.MySQLPrivilege + err = p.LoadAll(se) + require.NoError(t, err) + ret := p.FindAllUserEffectiveRoles("u1", "%", []*auth.RoleIdentity{ + {Username: "r_1", Hostname: "%"}, + {Username: "r_2", Hostname: "%"}, + }) + require.Equal(t, 4, len(ret)) + require.Equal(t, "r_1", ret[0].Username) + require.Equal(t, "r_2", ret[1].Username) + require.Equal(t, "r_3", ret[2].Username) + require.Equal(t, "r_4", ret[3].Username) + + mustExec(t, se, `REVOKE r_2 from u1`) + err = p.LoadAll(se) + require.NoError(t, err) + ret = p.FindAllUserEffectiveRoles("u1", "%", []*auth.RoleIdentity{ + {Username: "r_1", Hostname: "%"}, + {Username: "r_2", Hostname: "%"}, + }) + require.Equal(t, 2, len(ret)) + require.Equal(t, "r_1", ret[0].Username) + require.Equal(t, "r_3", ret[1].Username) +} + func TestAbnormalMySQLTable(t *testing.T) { t.Parallel() store, clean := newStore(t) diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 104c2c3782387..ea7059d72804e 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -542,7 +542,7 @@ func (p *UserPrivileges) DBIsVisible(activeRoles []*auth.RoleIdentity, db string if mysqlPriv.DBIsVisible(p.user, p.host, db) { return true } - allRoles := mysqlPriv.FindAllRole(activeRoles) + allRoles := mysqlPriv.FindAllUserEffectiveRoles(p.user, p.host, activeRoles) for _, role := range allRoles { if mysqlPriv.DBIsVisible(role.Username, role.Hostname, db) { return true diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index 6d9c6dd0d62bc..8bc42a921a3a2 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -2971,3 +2971,34 @@ func TestSkipGrantTable(t *testing.T) { tk.MustExec(`GRANT RESTRICTED_TABLES_ADMIN ON *.* TO 'test2'@'%';`) tk.MustExec(`GRANT RESTRICTED_USER_ADMIN ON *.* TO 'test2'@'%';`) } + +func TestIssue29823(t *testing.T) { + t.Parallel() + store, clean := newStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("create user u1") + tk.MustExec("create role r1") + tk.MustExec("create table t1 (c1 int)") + tk.MustExec("grant select on t1 to r1") + tk.MustExec("grant r1 to u1") + + tk2 := testkit.NewTestKit(t, store) + require.True(t, tk2.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "%"}, nil, nil)) + tk2.MustExec("set role all") + tk2.MustQuery("select current_role()").Check(testkit.Rows("`r1`@`%`")) + tk2.MustQuery("select * from test.t1").Check(testkit.Rows()) + tk2.MustQuery("show databases like 'test'").Check(testkit.Rows("test")) + tk2.MustQuery("show tables from test").Check(testkit.Rows("t1")) + + tk.MustExec("revoke r1 from u1") + tk2.MustQuery("select current_role()").Check(testkit.Rows("`r1`@`%`")) + err := tk2.ExecToErr("select * from test.t1") + require.EqualError(t, err, "[planner:1142]SELECT command denied to user 'u1'@'%' for table 't1'") + tk2.MustQuery("show databases like 'test'").Check(testkit.Rows()) + err = tk2.QueryToErr("show tables from test") + require.EqualError(t, err, "[executor:1044]Access denied for user 'u1'@'%' to database 'test'") +}