diff --git a/executor/simple.go b/executor/simple.go index 8f645b176d2f5..5cd1ef5259d95 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -87,11 +87,152 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro err = e.executeSetRole(x) case *ast.RevokeRoleStmt: err = e.executeRevokeRole(x) + case *ast.SetDefaultRoleStmt: + err = e.executeSetDefaultRole(x) } e.done = true return err } +func (e *SimpleExec) setDefaultRoleNone(s *ast.SetDefaultRoleStmt) error { + sqlExecutor := e.ctx.(sqlexec.SQLExecutor) + if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + return err + } + for _, u := range s.UserList { + if u.Hostname == "" { + u.Hostname = "%" + } + sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", u.Username, u.Hostname) + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql)) + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + } + if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + return err + } + return nil +} + +func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error { + for _, user := range s.UserList { + exists, err := userExists(e.ctx, user.Username, user.Hostname) + if err != nil { + return err + } + if !exists { + return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", user.String()) + } + } + for _, role := range s.RoleList { + exists, err := userExists(e.ctx, role.Username, role.Hostname) + if err != nil { + return err + } + if !exists { + return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", role.String()) + } + } + sqlExecutor := e.ctx.(sqlexec.SQLExecutor) + if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + return err + } + for _, user := range s.UserList { + if user.Hostname == "" { + user.Hostname = "%" + } + sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql)) + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + for _, role := range s.RoleList { + sql := fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles values('%s', '%s', '%s', '%s');", user.Hostname, user.Username, role.Hostname, role.Username) + checker := privilege.GetPrivilegeManager(e.ctx) + ok := checker.FindEdge(e.ctx, role, user) + if ok { + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql)) + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + } else { + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) + } + } + } + if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + return err + } + return nil +} + +func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error { + for _, user := range s.UserList { + exists, err := userExists(e.ctx, user.Username, user.Hostname) + if err != nil { + return err + } + if !exists { + return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", user.String()) + } + } + sqlExecutor := e.ctx.(sqlexec.SQLExecutor) + if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + return err + } + for _, user := range s.UserList { + if user.Hostname == "" { + user.Hostname = "%" + } + sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql)) + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+ + "SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username) + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + } + if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + return err + } + return nil +} + +func (e *SimpleExec) executeSetDefaultRole(s *ast.SetDefaultRoleStmt) error { + switch s.SetRoleOpt { + case ast.SetRoleAll: + return e.setDefaultRoleAll(s) + case ast.SetRoleNone: + return e.setDefaultRoleNone(s) + case ast.SetRoleRegular: + return e.setDefaultRoleRegular(s) + } + err := domain.GetDomain(e.ctx).PrivilegeHandle().Update(e.ctx.(sessionctx.Context)) + return err +} + func (e *SimpleExec) executeSetRole(s *ast.SetRoleStmt) error { checkDup := make(map[string]*auth.RoleIdentity, len(s.RoleList)) // Check whether RoleNameList contain duplicate role name. diff --git a/executor/simple_test.go b/executor/simple_test.go index 6cc38f1778af5..aaeda7a67beec 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -160,6 +160,56 @@ func (s *testSuite3) TestRole(c *C) { tk.MustExec(dropRoleSQL) } +func (s *testSuite3) TestDefaultRole(c *C) { + tk := testkit.NewTestKit(c, s.store) + + createRoleSQL := `CREATE ROLE r_1, r_2, r_3, u_1;` + tk.MustExec(createRoleSQL) + + tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('%','r_1','%','u_1')") + tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('%','r_2','%','u_1')") + + tk.MustExec("flush privileges;") + + setRoleSQL := `SET DEFAULT ROLE r_3 TO u_1;` + _, err := tk.Exec(setRoleSQL) + c.Check(err, NotNil) + + setRoleSQL = `SET DEFAULT ROLE r_1 TO u_1000;` + _, err = tk.Exec(setRoleSQL) + c.Check(err, NotNil) + + setRoleSQL = `SET DEFAULT ROLE r_1, r_3 TO u_1;` + _, err = tk.Exec(setRoleSQL) + c.Check(err, NotNil) + + setRoleSQL = `SET DEFAULT ROLE r_1 TO u_1;` + _, err = tk.Exec(setRoleSQL) + c.Check(err, IsNil) + result := tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`) + result.Check(testkit.Rows("r_1")) + setRoleSQL = `SET DEFAULT ROLE r_2 TO u_1;` + _, err = tk.Exec(setRoleSQL) + c.Check(err, IsNil) + result = tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`) + result.Check(testkit.Rows("r_2")) + + setRoleSQL = `SET DEFAULT ROLE ALL TO u_1;` + _, err = tk.Exec(setRoleSQL) + c.Check(err, IsNil) + result = tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`) + result.Check(testkit.Rows("r_1", "r_2")) + + setRoleSQL = `SET DEFAULT ROLE NONE TO u_1;` + _, err = tk.Exec(setRoleSQL) + c.Check(err, IsNil) + result = tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`) + result.Check(nil) + + dropRoleSQL := `DROP USER r_1, r_2, r_3, u_1;` + tk.MustExec(dropRoleSQL) +} + func (s *testSuite3) TestUser(c *C) { tk := testkit.NewTestKit(c, s.store) // Make sure user test not in mysql.User. diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 5213025746245..bd408d75f952f 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -268,7 +268,7 @@ func (b *PlanBuilder) Build(node ast.Node) (Plan, error) { case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt, *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt, - *ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt: + *ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt: return b.buildSimple(node.(ast.StmtNode)) case ast.DDLNode: return b.buildDDL(x) @@ -1095,7 +1095,7 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) { err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err) } - case *ast.AlterUserStmt: + case *ast.AlterUserStmt, *ast.SetDefaultRoleStmt: err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err) case *ast.GrantStmt: diff --git a/privilege/privilege.go b/privilege/privilege.go index b6aa136c6c806..e94ae4e9f8dae 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -56,6 +56,12 @@ type Manager interface { // ActiveRoles active roles for current session. // The first illegal role will be returned. ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string) + + // FindEdge find if there is an edge between role and user. + FindEdge(ctx sessionctx.Context, role *auth.RoleIdentity, user *auth.UserIdentity) bool + + // GetDefaultRoles returns all default roles for certain user. + GetDefaultRoles(user, host string) []*auth.RoleIdentity } const key keyType = 0 diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index 2731f6f39dad1..07055813ae2e3 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -68,7 +68,7 @@ type dbRecord struct { User string Privileges mysql.PrivilegeType - // patChars is compiled from Host and DB, cached for pattern match performance. + // hostPatChars is compiled from Host and DB, cached for pattern match performance. hostPatChars []byte hostPatTypes []byte @@ -105,7 +105,19 @@ type columnsPrivRecord struct { patTypes []byte } -// RoleGraphEdgesTable is used to cache relationship between and role. +// defaultRoleRecord is used to cache mysql.default_roles +type defaultRoleRecord struct { + Host string + User string + DefaultRoleUser string + DefaultRoleHost string + + // patChars is compiled from Host, cached for pattern match performance. + patChars []byte + patTypes []byte +} + +// roleGraphEdgesTable is used to cache relationship between and role. type roleGraphEdgesTable struct { roleList map[string]bool } @@ -125,11 +137,12 @@ func (g roleGraphEdgesTable) Find(user, host string) bool { // MySQLPrivilege is the in-memory cache of mysql privilege tables. type MySQLPrivilege struct { - User []UserRecord - DB []dbRecord - TablesPriv []tablesPrivRecord - ColumnsPriv []columnsPrivRecord - RoleGraph map[string]roleGraphEdgesTable + User []UserRecord + DB []dbRecord + TablesPriv []tablesPrivRecord + ColumnsPriv []columnsPrivRecord + DefaultRoles []defaultRoleRecord + RoleGraph map[string]roleGraphEdgesTable } // FindRole is used to detect whether there is edges between users and roles. @@ -166,6 +179,14 @@ func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error { log.Warn("mysql.tables_priv missing") } + err = p.LoadDefaultRoles(ctx) + if err != nil { + if !noSuchTable(err) { + return errors.Trace(err) + } + log.Warn("mysql.default_roles missing") + } + err = p.LoadColumnsPrivTable(ctx) if err != nil { if !noSuchTable(err) { @@ -316,6 +337,11 @@ func (p *MySQLPrivilege) LoadColumnsPrivTable(ctx sessionctx.Context) error { return p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Table_name,Column_name,Timestamp,Column_priv from mysql.columns_priv", p.decodeColumnsPrivTableRow) } +// LoadDefaultRoles loads the mysql.columns_priv table from database. +func (p *MySQLPrivilege) LoadDefaultRoles(ctx sessionctx.Context) error { + return p.loadTable(ctx, "select HOST, USER, DEFAULT_ROLE_HOST, DEFAULT_ROLE_USER from mysql.default_roles", p.decodeDefaultRoleTableRow) +} + func (p *MySQLPrivilege) loadTable(sctx sessionctx.Context, sql string, decodeTableRow func(chunk.Row, []*ast.ResultField) error) error { ctx := context.Background() @@ -455,6 +481,25 @@ func (p *MySQLPrivilege) decodeRoleEdgesTable(row chunk.Row, fs []*ast.ResultFie return nil } +func (p *MySQLPrivilege) decodeDefaultRoleTableRow(row chunk.Row, fs []*ast.ResultField) error { + var value defaultRoleRecord + for i, f := range fs { + switch { + case f.ColumnAsName.L == "host": + value.Host = row.GetString(i) + value.patChars, value.patTypes = stringutil.CompilePattern(value.Host, '\\') + case f.ColumnAsName.L == "user": + value.User = row.GetString(i) + case f.ColumnAsName.L == "default_role_host": + value.DefaultRoleHost = row.GetString(i) + case f.ColumnAsName.L == "default_role_user": + value.DefaultRoleUser = row.GetString(i) + } + } + p.DefaultRoles = append(p.DefaultRoles, value) + return nil +} + func (p *MySQLPrivilege) decodeColumnsPrivTableRow(row chunk.Row, fs []*ast.ResultField) error { var value columnsPrivRecord for i, f := range fs { @@ -522,6 +567,10 @@ func (record *columnsPrivRecord) match(user, host, db, table, col string) bool { patternMatch(host, record.patChars, record.patTypes) } +func (record *defaultRoleRecord) match(user, host string) bool { + return record.User == user && patternMatch(host, record.patChars, record.patTypes) +} + // patternMatch matches "%" the same way as ".*" in regular expression, for example, // "10.0.%" would match "10.0.1" "10.0.1.118" ... func patternMatch(str string, patChars, patTypes []byte) bool { @@ -766,6 +815,16 @@ func appendUserPrivilegesTableRow(rows [][]types.Datum, user UserRecord) [][]typ return rows } +func (p *MySQLPrivilege) getDefaultRoles(user, host string) []*auth.RoleIdentity { + ret := make([]*auth.RoleIdentity, 0) + for _, r := range p.DefaultRoles { + if r.match(user, host) { + ret = append(ret, &auth.RoleIdentity{Username: r.DefaultRoleUser, Hostname: r.DefaultRoleHost}) + } + } + return ret +} + // Handle wraps MySQLPrivilege providing thread safe access. type Handle struct { priv atomic.Value diff --git a/privilege/privileges/cache_test.go b/privilege/privileges/cache_test.go index 032b3586ee120..fe9e6c740035e 100644 --- a/privilege/privileges/cache_test.go +++ b/privilege/privileges/cache_test.go @@ -134,6 +134,25 @@ func (s *testCacheSuite) TestLoadColumnsPrivTable(c *C) { c.Assert(p.ColumnsPriv[1].ColumnPriv, Equals, mysql.SelectPriv) } +func (s *testCacheSuite) TestLoadDefaultRoleTable(c *C) { + se, err := session.CreateSession4Test(s.store) + c.Assert(err, IsNil) + defer se.Close() + mustExec(c, se, "use mysql;") + mustExec(c, se, "truncate table default_roles") + + mustExec(c, se, `INSERT INTO mysql.default_roles VALUES ("%", "test_default_roles", "localhost", "r_1")`) + mustExec(c, se, `INSERT INTO mysql.default_roles VALUES ("%", "test_default_roles", "localhost", "r_2")`) + var p privileges.MySQLPrivilege + err = p.LoadDefaultRoles(se) + c.Assert(err, IsNil) + c.Assert(p.DefaultRoles[0].Host, Equals, `%`) + c.Assert(p.DefaultRoles[0].User, Equals, "test_default_roles") + c.Assert(p.DefaultRoles[0].DefaultRoleHost, Equals, "localhost") + c.Assert(p.DefaultRoles[0].DefaultRoleUser, Equals, "r_1") + c.Assert(p.DefaultRoles[1].DefaultRoleHost, Equals, "localhost") +} + func (s *testCacheSuite) TestPatternMatch(c *C) { se, err := session.CreateSession4Test(s.store) c.Assert(err, IsNil) diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index a8d8a6cd7d816..98683977255cb 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -206,3 +206,21 @@ func (p *UserPrivileges) ActiveRoles(ctx sessionctx.Context, roleList []*auth.Ro ctx.GetSessionVars().ActiveRoles = roleList return true, "" } + +// FindEdge implements privilege.Manager FindRelationship interface. +func (p *UserPrivileges) FindEdge(ctx sessionctx.Context, role *auth.RoleIdentity, user *auth.UserIdentity) bool { + mysqlPrivilege := p.Handle.Get() + ok := mysqlPrivilege.FindRole(user.Username, user.Hostname, role) + if !ok { + logutil.Logger(context.Background()).Error("find role failed", zap.Stringer("role", role)) + return false + } + return true +} + +// GetDefaultRoles returns all default roles for certain user. +func (p *UserPrivileges) GetDefaultRoles(user, host string) []*auth.RoleIdentity { + mysqlPrivilege := p.Handle.Get() + ret := mysqlPrivilege.getDefaultRoles(user, host) + return ret +} diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index faa66d18ea3ed..1767d6bdd761e 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -494,6 +494,29 @@ func (s *testPrivilegeSuite) TestGetEncodedPassword(c *C) { c.Assert(pc.GetEncodedPassword("test_encode_u", "localhost"), Equals, "*81F5E21E35407D884A6CD4A731AEBFB6AF209E1B") } +func (s *testPrivilegeSuite) TestDefaultRoles(c *C) { + rootSe := newSession(c, s.store, s.dbName) + mustExec(c, rootSe, `CREATE USER 'testdefault'@'localhost';`) + mustExec(c, rootSe, `CREATE ROLE 'testdefault_r1'@'localhost', 'testdefault_r2'@'localhost';`) + mustExec(c, rootSe, `GRANT 'testdefault_r1'@'localhost', 'testdefault_r2'@'localhost' TO 'testdefault'@'localhost';`) + + se := newSession(c, s.store, s.dbName) + pc := privilege.GetPrivilegeManager(se) + + ret := pc.GetDefaultRoles("testdefault", "localhost") + c.Assert(len(ret), Equals, 0) + + mustExec(c, rootSe, `SET DEFAULT ROLE ALL TO 'testdefault'@'localhost';`) + mustExec(c, rootSe, `flush privileges;`) + ret = pc.GetDefaultRoles("testdefault", "localhost") + c.Assert(len(ret), Equals, 2) + + mustExec(c, rootSe, `SET DEFAULT ROLE NONE TO 'testdefault'@'localhost';`) + mustExec(c, rootSe, `flush privileges;`) + ret = pc.GetDefaultRoles("testdefault", "localhost") + c.Assert(len(ret), Equals, 0) +} + func mustExec(c *C, se session.Session, sql string) { _, err := se.Execute(context.Background(), sql) c.Assert(err, IsNil) diff --git a/server/server_test.go b/server/server_test.go index 57179c4c091dd..8271c2391b584 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -804,7 +804,10 @@ func runTestShowProcessList(c *C) { func runTestAuth(c *C) { runTests(c, nil, func(dbt *DBTest) { dbt.mustExec(`CREATE USER 'authtest'@'%' IDENTIFIED BY '123';`) + dbt.mustExec(`CREATE ROLE 'authtest_r1'@'%';`) dbt.mustExec(`GRANT ALL on test.* to 'authtest'`) + dbt.mustExec(`GRANT authtest_r1 to 'authtest'`) + dbt.mustExec(`SET DEFAULT ROLE authtest_r1 TO authtest`) dbt.mustExec(`FLUSH PRIVILEGES;`) }) runTests(c, func(config *mysql.Config) { @@ -823,6 +826,21 @@ func runTestAuth(c *C) { c.Assert(err, NotNil, Commentf("Wrong password should be failed")) db.Close() + // Test for loading active roles. + db, err = sql.Open("mysql", getDSN(func(config *mysql.Config) { + config.User = "authtest" + config.Passwd = "123" + })) + c.Assert(err, IsNil) + rows, err := db.Query("select current_role;") + c.Assert(err, IsNil) + c.Assert(rows.Next(), IsTrue) + var outA string + err = rows.Scan(&outA) + c.Assert(err, IsNil) + c.Assert(outA, Equals, "`authtest_r1`@`%`") + db.Close() + // Test login use IP that not exists in mysql.user. runTests(c, nil, func(dbt *DBTest) { dbt.mustExec(`CREATE USER 'authtest2'@'localhost' IDENTIFIED BY '123';`) diff --git a/session/session.go b/session/session.go index fb6d48d39eec1..c63d3f185fbc5 100644 --- a/session/session.go +++ b/session/session.go @@ -1196,6 +1196,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt) if success { s.sessionVars.User = user + s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) return true } else if user.Hostname == variable.DefHostname { logutil.Logger(context.Background()).Error("user connection verification failed", @@ -1213,6 +1214,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by AuthUsername: u, AuthHostname: h, } + s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h) return true } }