Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: make TiDB more robust when synchronize mysql.user table #2722

Merged
merged 5 commits into from
Feb 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 13 additions & 23 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"time"

"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
Expand Down Expand Up @@ -116,22 +117,22 @@ func (p *MySQLPrivilege) LoadAll(ctx context.Context) error {

// LoadUserTable loads the mysql.user table from database.
func (p *MySQLPrivilege) LoadUserTable(ctx context.Context) error {
return p.loadTable(ctx, "select * from mysql.user order by host, user;", p.decodeUserTableRow)
return p.loadTable(ctx, "select Host,User,Password,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Alter_priv,Show_db_priv,Execute_priv,Index_priv,Create_user_priv from mysql.user order by host, user;", p.decodeUserTableRow)
}

// LoadDBTable loads the mysql.db table from database.
func (p *MySQLPrivilege) LoadDBTable(ctx context.Context) error {
return p.loadTable(ctx, "select * from mysql.db order by host, db, user;", p.decodeDBTableRow)
return p.loadTable(ctx, "select Host,DB,User,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,Alter_priv,Execute_priv from mysql.db order by host, db, user;", p.decodeDBTableRow)
}

// LoadTablesPrivTable loads the mysql.tables_priv table from database.
func (p *MySQLPrivilege) LoadTablesPrivTable(ctx context.Context) error {
return p.loadTable(ctx, "select * from mysql.tables_priv", p.decodeTablesPrivTableRow)
return p.loadTable(ctx, "select Host,DB,User,Table_name,Grantor,Timestamp,Table_priv,Column_priv from mysql.tables_priv", p.decodeTablesPrivTableRow)
}

// LoadColumnsPrivTable loads the mysql.columns_priv table from database.
func (p *MySQLPrivilege) LoadColumnsPrivTable(ctx context.Context) error {
return p.loadTable(ctx, "select * from mysql.columns_priv", p.decodeColumnsPrivTableRow)
return p.loadTable(ctx, "select Host,DB,User,Table_name,Column_name,Timestamp,Column_priv from mysql.columns_priv", p.decodeColumnsPrivTableRow)
}

func (p *MySQLPrivilege) loadTable(ctx context.Context, sql string,
Expand Down Expand Up @@ -235,17 +236,9 @@ func (p *MySQLPrivilege) decodeTablesPrivTableRow(row *ast.Row, fs []*ast.Result
case f.ColumnAsName.L == "table_name":
value.TableName = d.GetString()
case f.ColumnAsName.L == "table_priv":
priv, err := decodeSetToPrivilege(d.GetMysqlSet())
if err != nil {
return errors.Trace(err)
}
value.TablePriv = priv
value.TablePriv = decodeSetToPrivilege(d.GetMysqlSet())
case f.ColumnAsName.L == "column_priv":
priv, err := decodeSetToPrivilege(d.GetMysqlSet())
if err != nil {
return errors.Trace(err)
}
value.ColumnPriv = priv
value.ColumnPriv = decodeSetToPrivilege(d.GetMysqlSet())
}
}
p.TablesPriv = append(p.TablesPriv, value)
Expand All @@ -271,30 +264,27 @@ func (p *MySQLPrivilege) decodeColumnsPrivTableRow(row *ast.Row, fs []*ast.Resul
case f.ColumnAsName.L == "timestamp":
value.Timestamp, _ = d.GetMysqlTime().Time.GoTime(time.Local)
case f.ColumnAsName.L == "column_priv":
priv, err := decodeSetToPrivilege(d.GetMysqlSet())
if err != nil {
return errors.Trace(err)
}
value.ColumnPriv = priv
value.ColumnPriv = decodeSetToPrivilege(d.GetMysqlSet())
}
}
p.ColumnsPriv = append(p.ColumnsPriv, value)
return nil
}

func decodeSetToPrivilege(s types.Set) (mysql.PrivilegeType, error) {
func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType {
var ret mysql.PrivilegeType
if s.Name == "" {
return ret, nil
return ret
}
for _, str := range strings.Split(s.Name, ",") {
priv, ok := mysql.SetStr2Priv[str]
if !ok {
return ret, errInvalidPrivilegeType
log.Warn("unsupported privilege type:", str)
continue
}
ret |= priv
}
return ret, nil
return ret
}

func (record *userRecord) match(user, host string) bool {
Expand Down
69 changes: 69 additions & 0 deletions privilege/privileges/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,72 @@ func (s *testCacheSuite) TestCaseInsensitive(c *C) {
c.Assert(p.RequestVerification("genius", "127.0.0.1", "TCTRAIN", "TCTRAINORDER", "", mysql.SelectPriv), IsTrue)
c.Assert(p.RequestVerification("genius", "127.0.0.1", "tctrain", "tctrainorder", "", mysql.SelectPriv), IsTrue)
}

func (s *testCacheSuite) TestAfterSyncMySQLUser(c *C) {
privileges.Enable = true
store, err := tidb.NewStore("memory://sync_mysql_user")
c.Assert(err, IsNil)
domain, err := tidb.BootstrapSession(store)
c.Assert(err, IsNil)
defer domain.Close()

se, err := tidb.CreateSession(store)
c.Assert(err, IsNil)
defer se.Close()

mustExec(c, se, "DROP TABLE mysql.user;")
mustExec(c, se, "USE mysql;")
mustExec(c, se, `CREATE TABLE user (
Host char(60) COLLATE utf8_bin NOT NULL DEFAULT '',
User char(16) COLLATE utf8_bin NOT NULL DEFAULT '',
Password char(41) CHARACTER SET latin1 COLLATE latin1_bin NOT NULL DEFAULT '',
Select_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Insert_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Update_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Delete_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Create_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Drop_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Reload_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Shutdown_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Process_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
File_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Grant_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
References_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Index_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Alter_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Show_db_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Super_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Create_tmp_table_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Lock_tables_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Execute_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Repl_slave_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Repl_client_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Create_view_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Show_view_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Create_routine_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Alter_routine_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Create_user_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Event_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Trigger_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
Create_tablespace_priv enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
ssl_type enum('','ANY','X509','SPECIFIED') CHARACTER SET utf8 NOT NULL DEFAULT '',
ssl_cipher blob NOT NULL,
x509_issuer blob NOT NULL,
x509_subject blob NOT NULL,
max_questions int(11) unsigned NOT NULL DEFAULT '0',
max_updates int(11) unsigned NOT NULL DEFAULT '0',
max_connections int(11) unsigned NOT NULL DEFAULT '0',
max_user_connections int(11) unsigned NOT NULL DEFAULT '0',
plugin char(64) COLLATE utf8_bin DEFAULT 'mysql_native_password',
authentication_string text COLLATE utf8_bin,
password_expired enum('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N',
PRIMARY KEY (Host,User)
) ENGINE=MyISAM DEFAULT CHARSET=utf8 COLLATE=utf8_bin COMMENT='Users and global privileges';`)
mustExec(c, se, `INSERT INTO user VALUES ('localhost','root','','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','','','','',0,0,0,0,'mysql_native_password','','N');
`)
var p privileges.MySQLPrivilege
err = p.LoadUserTable(se)
c.Assert(err, IsNil)
// MySQL mysql.user table schema is not identical to TiDB, check it doesn't break privilege.
c.Assert(p.RequestVerification("root", "localhost", "test", "", "", mysql.SelectPriv), IsTrue)
}
6 changes: 3 additions & 3 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ const userTablePrivColumnStartIndex = 3
const dbTablePrivColumnStartIndex = 3

func (p *UserPrivileges) loadGlobalPrivileges(ctx context.Context) error {
sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`,
sql := fmt.Sprintf(`SELECT Host,User,Password,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Alter_priv,Show_db_priv,Execute_priv,Index_priv,Create_user_priv FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`,
mysql.SystemDB, mysql.UserTable, p.privs.User, p.privs.Host)
rows, fs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
if err != nil {
Expand Down Expand Up @@ -348,7 +348,7 @@ func (p *UserPrivileges) loadGlobalPrivileges(ctx context.Context) error {
}

func (p *UserPrivileges) loadDBScopePrivileges(ctx context.Context) error {
sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`,
sql := fmt.Sprintf(`SELECT Host,DB,User,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,Alter_priv,Execute_priv FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`,
mysql.SystemDB, mysql.DBTable, p.privs.User, p.privs.Host)
rows, fs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
if err != nil {
Expand Down Expand Up @@ -381,7 +381,7 @@ func (p *UserPrivileges) loadDBScopePrivileges(ctx context.Context) error {
}

func (p *UserPrivileges) loadTableScopePrivileges(ctx context.Context) error {
sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`,
sql := fmt.Sprintf(`SELECT Host,DB,User,Table_name,Grantor,Timestamp,Table_priv,Column_priv FROM %s.%s WHERE User="%s" AND (Host="%s" OR Host="%%");`,
mysql.SystemDB, mysql.TablePrivTable, p.privs.User, p.privs.Host)
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,15 @@ func runTestAuth(c *C) {
_, err = db.Query("USE mysql;")
c.Assert(err, NotNil, Commentf("Wrong password should be failed"))
db.Close()

// Test login use IP that not exists in mysql.user.
runTests(c, dsn, func(dbt *DBTest) {
dbt.mustExec(`CREATE USER 'xxx'@'localhost' IDENTIFIED BY 'yyy';`)
})
newDsn = "xxx:yyy@tcp(127.0.0.1:4001)/test?strict=true"
runTests(c, newDsn, func(dbt *DBTest) {
dbt.mustExec(`USE mysql;`)
})
}

func runTestIssues(c *C) {
Expand Down
31 changes: 25 additions & 6 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
goctx "context"
"encoding/json"
"fmt"
"net"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -707,14 +708,32 @@ func (s *session) Auth(user string, auth []byte, salt []byte) bool {
// Get user password.
name := strs[0]
host := strs[1]

checker := privilege.GetPrivilegeChecker(s)
if !checker.ConnectionVerification(name, host, auth, salt) {
log.Errorf("User connection verification failed %v", name)
return false

// Check IP.
if checker.ConnectionVerification(name, host, auth, salt) {
s.sessionVars.User = name + "@" + host
return true
}

// Check Hostname.
for _, addr := range getHostByIP(host) {
if checker.ConnectionVerification(name, addr, auth, salt) {
s.sessionVars.User = name + "@" + addr
return true
}
}

log.Errorf("User connection verification failed %v", user)
return false
}

func getHostByIP(ip string) []string {
if ip == "127.0.0.1" {
return []string{"localhost"}
}
s.sessionVars.User = user
return true
addrs, _ := net.LookupAddr(ip)
return addrs
}

// Some vars name for debug.
Expand Down