Skip to content

Commit

Permalink
privilege, executor: support SET DEFAULT ROLE (#9949)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lingyu Song authored and jackysp committed Apr 17, 2019
1 parent e5f734e commit abeddab
Show file tree
Hide file tree
Showing 10 changed files with 345 additions and 9 deletions.
141 changes: 141 additions & 0 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,152 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro
err = e.executeSetRole(x)
case *ast.RevokeRoleStmt:
err = e.executeRevokeRole(x)
case *ast.SetDefaultRoleStmt:
err = e.executeSetDefaultRole(x)
}
e.done = true
return err
}

func (e *SimpleExec) setDefaultRoleNone(s *ast.SetDefaultRoleStmt) error {
sqlExecutor := e.ctx.(sqlexec.SQLExecutor)
if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil {
return err
}
for _, u := range s.UserList {
if u.Hostname == "" {
u.Hostname = "%"
}
sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", u.Username, u.Hostname)
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
}
if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil {
return err
}
return nil
}

func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error {
for _, user := range s.UserList {
exists, err := userExists(e.ctx, user.Username, user.Hostname)
if err != nil {
return err
}
if !exists {
return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", user.String())
}
}
for _, role := range s.RoleList {
exists, err := userExists(e.ctx, role.Username, role.Hostname)
if err != nil {
return err
}
if !exists {
return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", role.String())
}
}
sqlExecutor := e.ctx.(sqlexec.SQLExecutor)
if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil {
return err
}
for _, user := range s.UserList {
if user.Hostname == "" {
user.Hostname = "%"
}
sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname)
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
for _, role := range s.RoleList {
sql := fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles values('%s', '%s', '%s', '%s');", user.Hostname, user.Username, role.Hostname, role.Username)
checker := privilege.GetPrivilegeManager(e.ctx)
ok := checker.FindEdge(e.ctx, role, user)
if ok {
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
} else {
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String())
}
}
}
if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil {
return err
}
return nil
}

func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error {
for _, user := range s.UserList {
exists, err := userExists(e.ctx, user.Username, user.Hostname)
if err != nil {
return err
}
if !exists {
return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", user.String())
}
}
sqlExecutor := e.ctx.(sqlexec.SQLExecutor)
if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil {
return err
}
for _, user := range s.UserList {
if user.Hostname == "" {
user.Hostname = "%"
}
sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname)
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+
"SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username)
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
}
if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil {
return err
}
return nil
}

func (e *SimpleExec) executeSetDefaultRole(s *ast.SetDefaultRoleStmt) error {
switch s.SetRoleOpt {
case ast.SetRoleAll:
return e.setDefaultRoleAll(s)
case ast.SetRoleNone:
return e.setDefaultRoleNone(s)
case ast.SetRoleRegular:
return e.setDefaultRoleRegular(s)
}
err := domain.GetDomain(e.ctx).PrivilegeHandle().Update(e.ctx.(sessionctx.Context))
return err
}

func (e *SimpleExec) executeSetRole(s *ast.SetRoleStmt) error {
checkDup := make(map[string]*auth.RoleIdentity, len(s.RoleList))
// Check whether RoleNameList contain duplicate role name.
Expand Down
50 changes: 50 additions & 0 deletions executor/simple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,56 @@ func (s *testSuite3) TestRole(c *C) {
tk.MustExec(dropRoleSQL)
}

func (s *testSuite3) TestDefaultRole(c *C) {
tk := testkit.NewTestKit(c, s.store)

createRoleSQL := `CREATE ROLE r_1, r_2, r_3, u_1;`
tk.MustExec(createRoleSQL)

tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('%','r_1','%','u_1')")
tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('%','r_2','%','u_1')")

tk.MustExec("flush privileges;")

setRoleSQL := `SET DEFAULT ROLE r_3 TO u_1;`
_, err := tk.Exec(setRoleSQL)
c.Check(err, NotNil)

setRoleSQL = `SET DEFAULT ROLE r_1 TO u_1000;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, NotNil)

setRoleSQL = `SET DEFAULT ROLE r_1, r_3 TO u_1;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, NotNil)

setRoleSQL = `SET DEFAULT ROLE r_1 TO u_1;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, IsNil)
result := tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`)
result.Check(testkit.Rows("r_1"))
setRoleSQL = `SET DEFAULT ROLE r_2 TO u_1;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, IsNil)
result = tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`)
result.Check(testkit.Rows("r_2"))

setRoleSQL = `SET DEFAULT ROLE ALL TO u_1;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, IsNil)
result = tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`)
result.Check(testkit.Rows("r_1", "r_2"))

setRoleSQL = `SET DEFAULT ROLE NONE TO u_1;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, IsNil)
result = tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`)
result.Check(nil)

dropRoleSQL := `DROP USER r_1, r_2, r_3, u_1;`
tk.MustExec(dropRoleSQL)
}

func (s *testSuite3) TestUser(c *C) {
tk := testkit.NewTestKit(c, s.store)
// Make sure user test not in mysql.User.
Expand Down
4 changes: 2 additions & 2 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func (b *PlanBuilder) Build(node ast.Node) (Plan, error) {
case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt,
*ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt,
*ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt:
*ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt:
return b.buildSimple(node.(ast.StmtNode))
case ast.DDLNode:
return b.buildDDL(x)
Expand Down Expand Up @@ -1095,7 +1095,7 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) {
err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err)
}
case *ast.AlterUserStmt:
case *ast.AlterUserStmt, *ast.SetDefaultRoleStmt:
err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err)
case *ast.GrantStmt:
Expand Down
6 changes: 6 additions & 0 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ type Manager interface {
// ActiveRoles active roles for current session.
// The first illegal role will be returned.
ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string)

// FindEdge find if there is an edge between role and user.
FindEdge(ctx sessionctx.Context, role *auth.RoleIdentity, user *auth.UserIdentity) bool

// GetDefaultRoles returns all default roles for certain user.
GetDefaultRoles(user, host string) []*auth.RoleIdentity
}

const key keyType = 0
Expand Down
73 changes: 66 additions & 7 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ type dbRecord struct {
User string
Privileges mysql.PrivilegeType

// patChars is compiled from Host and DB, cached for pattern match performance.
// hostPatChars is compiled from Host and DB, cached for pattern match performance.
hostPatChars []byte
hostPatTypes []byte

Expand Down Expand Up @@ -105,7 +105,19 @@ type columnsPrivRecord struct {
patTypes []byte
}

// RoleGraphEdgesTable is used to cache relationship between and role.
// defaultRoleRecord is used to cache mysql.default_roles
type defaultRoleRecord struct {
Host string
User string
DefaultRoleUser string
DefaultRoleHost string

// patChars is compiled from Host, cached for pattern match performance.
patChars []byte
patTypes []byte
}

// roleGraphEdgesTable is used to cache relationship between and role.
type roleGraphEdgesTable struct {
roleList map[string]bool
}
Expand All @@ -125,11 +137,12 @@ func (g roleGraphEdgesTable) Find(user, host string) bool {

// MySQLPrivilege is the in-memory cache of mysql privilege tables.
type MySQLPrivilege struct {
User []UserRecord
DB []dbRecord
TablesPriv []tablesPrivRecord
ColumnsPriv []columnsPrivRecord
RoleGraph map[string]roleGraphEdgesTable
User []UserRecord
DB []dbRecord
TablesPriv []tablesPrivRecord
ColumnsPriv []columnsPrivRecord
DefaultRoles []defaultRoleRecord
RoleGraph map[string]roleGraphEdgesTable
}

// FindRole is used to detect whether there is edges between users and roles.
Expand Down Expand Up @@ -166,6 +179,14 @@ func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error {
log.Warn("mysql.tables_priv missing")
}

err = p.LoadDefaultRoles(ctx)
if err != nil {
if !noSuchTable(err) {
return errors.Trace(err)
}
log.Warn("mysql.default_roles missing")
}

err = p.LoadColumnsPrivTable(ctx)
if err != nil {
if !noSuchTable(err) {
Expand Down Expand Up @@ -316,6 +337,11 @@ func (p *MySQLPrivilege) LoadColumnsPrivTable(ctx sessionctx.Context) error {
return p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Table_name,Column_name,Timestamp,Column_priv from mysql.columns_priv", p.decodeColumnsPrivTableRow)
}

// LoadDefaultRoles loads the mysql.columns_priv table from database.
func (p *MySQLPrivilege) LoadDefaultRoles(ctx sessionctx.Context) error {
return p.loadTable(ctx, "select HOST, USER, DEFAULT_ROLE_HOST, DEFAULT_ROLE_USER from mysql.default_roles", p.decodeDefaultRoleTableRow)
}

func (p *MySQLPrivilege) loadTable(sctx sessionctx.Context, sql string,
decodeTableRow func(chunk.Row, []*ast.ResultField) error) error {
ctx := context.Background()
Expand Down Expand Up @@ -455,6 +481,25 @@ func (p *MySQLPrivilege) decodeRoleEdgesTable(row chunk.Row, fs []*ast.ResultFie
return nil
}

func (p *MySQLPrivilege) decodeDefaultRoleTableRow(row chunk.Row, fs []*ast.ResultField) error {
var value defaultRoleRecord
for i, f := range fs {
switch {
case f.ColumnAsName.L == "host":
value.Host = row.GetString(i)
value.patChars, value.patTypes = stringutil.CompilePattern(value.Host, '\\')
case f.ColumnAsName.L == "user":
value.User = row.GetString(i)
case f.ColumnAsName.L == "default_role_host":
value.DefaultRoleHost = row.GetString(i)
case f.ColumnAsName.L == "default_role_user":
value.DefaultRoleUser = row.GetString(i)
}
}
p.DefaultRoles = append(p.DefaultRoles, value)
return nil
}

func (p *MySQLPrivilege) decodeColumnsPrivTableRow(row chunk.Row, fs []*ast.ResultField) error {
var value columnsPrivRecord
for i, f := range fs {
Expand Down Expand Up @@ -522,6 +567,10 @@ func (record *columnsPrivRecord) match(user, host, db, table, col string) bool {
patternMatch(host, record.patChars, record.patTypes)
}

func (record *defaultRoleRecord) match(user, host string) bool {
return record.User == user && patternMatch(host, record.patChars, record.patTypes)
}

// patternMatch matches "%" the same way as ".*" in regular expression, for example,
// "10.0.%" would match "10.0.1" "10.0.1.118" ...
func patternMatch(str string, patChars, patTypes []byte) bool {
Expand Down Expand Up @@ -766,6 +815,16 @@ func appendUserPrivilegesTableRow(rows [][]types.Datum, user UserRecord) [][]typ
return rows
}

func (p *MySQLPrivilege) getDefaultRoles(user, host string) []*auth.RoleIdentity {
ret := make([]*auth.RoleIdentity, 0)
for _, r := range p.DefaultRoles {
if r.match(user, host) {
ret = append(ret, &auth.RoleIdentity{Username: r.DefaultRoleUser, Hostname: r.DefaultRoleHost})
}
}
return ret
}

// Handle wraps MySQLPrivilege providing thread safe access.
type Handle struct {
priv atomic.Value
Expand Down
19 changes: 19 additions & 0 deletions privilege/privileges/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,25 @@ func (s *testCacheSuite) TestLoadColumnsPrivTable(c *C) {
c.Assert(p.ColumnsPriv[1].ColumnPriv, Equals, mysql.SelectPriv)
}

func (s *testCacheSuite) TestLoadDefaultRoleTable(c *C) {
se, err := session.CreateSession4Test(s.store)
c.Assert(err, IsNil)
defer se.Close()
mustExec(c, se, "use mysql;")
mustExec(c, se, "truncate table default_roles")

mustExec(c, se, `INSERT INTO mysql.default_roles VALUES ("%", "test_default_roles", "localhost", "r_1")`)
mustExec(c, se, `INSERT INTO mysql.default_roles VALUES ("%", "test_default_roles", "localhost", "r_2")`)
var p privileges.MySQLPrivilege
err = p.LoadDefaultRoles(se)
c.Assert(err, IsNil)
c.Assert(p.DefaultRoles[0].Host, Equals, `%`)
c.Assert(p.DefaultRoles[0].User, Equals, "test_default_roles")
c.Assert(p.DefaultRoles[0].DefaultRoleHost, Equals, "localhost")
c.Assert(p.DefaultRoles[0].DefaultRoleUser, Equals, "r_1")
c.Assert(p.DefaultRoles[1].DefaultRoleHost, Equals, "localhost")
}

func (s *testCacheSuite) TestPatternMatch(c *C) {
se, err := session.CreateSession4Test(s.store)
c.Assert(err, IsNil)
Expand Down
Loading

0 comments on commit abeddab

Please sign in to comment.