From 0896d1507b8bdda00228f574f105f52e71904985 Mon Sep 17 00:00:00 2001 From: lysu Date: Thu, 4 Apr 2019 13:34:41 +0800 Subject: [PATCH] plugin: add audit plugin extension point (#9136) (#9954) --- cmd/pluginpkg/pluginpkg.go | 15 ++- executor/adapter.go | 22 ++++ executor/set.go | 11 ++ planner/core/optimizer.go | 2 + planner/core/planbuilder.go | 21 ++++ plugin/audit.go | 87 ++++++++++++++ plugin/conn_ip_example/conn_ip_example.go | 4 +- .../conn_ip_example/conn_ip_example_test.go | 20 ++-- plugin/conn_ip_example/manifest.toml | 2 +- plugin/plugin.go | 81 +++++++++---- plugin/spi.go | 26 ++-- plugin/spi_test.go | 5 +- server/conn.go | 57 +++++++-- server/driver.go | 2 + server/driver_tidb.go | 5 + server/server.go | 111 +++++++++++++++++- session/session.go | 18 ++- sessionctx/stmtctx/stmtctx.go | 7 ++ sessionctx/variable/session.go | 23 ++++ util/sys/linux/sys_linux.go | 40 +++++++ util/sys/linux/sys_other.go | 24 ++++ util/sys/linux/sys_test.go | 29 +++++ 22 files changed, 541 insertions(+), 71 deletions(-) create mode 100644 plugin/audit.go create mode 100644 util/sys/linux/sys_linux.go create mode 100644 util/sys/linux/sys_other.go create mode 100644 util/sys/linux/sys_test.go diff --git a/cmd/pluginpkg/pluginpkg.go b/cmd/pluginpkg/pluginpkg.go index e1b1db5a3dba3..335a8f5b93a49 100644 --- a/cmd/pluginpkg/pluginpkg.go +++ b/cmd/pluginpkg/pluginpkg.go @@ -62,9 +62,18 @@ func PluginManifest() *plugin.Manifest { }, {{end}} }, - Validate: {{.validate}}, - OnInit: {{.onInit}}, - OnShutdown: {{.onShutdown}}, + {{if .validate }} + Validate: {{.validate}}, + {{end}} + {{if .onInit }} + OnInit: {{.onInit}}, + {{end}} + {{if .onShutdown }} + OnShutdown: {{.onShutdown}}, + {{end}} + {{if .onFlush }} + OnFlush: {{.onFlush}}, + {{end}} }, {{range .export}} {{.extPoint}}: {{.impl}}, diff --git a/executor/adapter.go b/executor/adapter.go index fa43702946a9d..a01e1d5f37c63 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -32,12 +32,14 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/plugin" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/sqlexec" log "github.com/sirupsen/logrus" + "go.uber.org/zap" "golang.org/x/net/context" ) @@ -124,6 +126,7 @@ func (a *recordSet) Close() error { if a.processinfo != nil { a.processinfo.SetProcessInfo("") } + a.stmt.logAudit() return errors.Trace(err) } @@ -286,6 +289,7 @@ func (a *ExecStmt) handleNoDelayExecutor(ctx context.Context, sctx sessionctx.Co } } a.LogSlowQuery(txnTS, err == nil) + a.logAudit() }() err = e.Next(ctx, e.newFirstChunk()) @@ -349,6 +353,24 @@ func (a *ExecStmt) buildExecutor(ctx sessionctx.Context) (Executor, error) { // QueryReplacer replaces new line and tab for grep result including query string. var QueryReplacer = strings.NewReplacer("\r", " ", "\n", " ", "\t", " ") +func (a *ExecStmt) logAudit() { + sessVars := a.Ctx.GetSessionVars() + if sessVars.InRestrictedSQL { + return + } + err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + audit := plugin.DeclareAuditManifest(p.Manifest) + if audit.OnGeneralEvent != nil { + cmd := mysql.Command2Str[byte(atomic.LoadUint32(&a.Ctx.GetSessionVars().CommandValue))] + audit.OnGeneralEvent(context.Background(), sessVars, plugin.Log, cmd) + } + return nil + }) + if err != nil { + logutil.Logger(context.Background()).Error("log audit log failure", zap.Error(err)) + } +} + // LogSlowQuery is used to print the slow query in the log files. func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool) { level := log.GetLevel() diff --git a/executor/set.go b/executor/set.go index 9ed9b5424413c..0e22422b9b87c 100644 --- a/executor/set.go +++ b/executor/set.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/plugin" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" @@ -139,6 +140,16 @@ func (e *SetExecutor) setSysVariable(name string, v *expression.VarAssignment) e if err != nil { return errors.Trace(err) } + err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + auditPlugin := plugin.DeclareAuditManifest(p.Manifest) + if auditPlugin.OnGlobalVariableEvent != nil { + auditPlugin.OnGlobalVariableEvent(context.Background(), e.ctx.GetSessionVars(), name, svalue) + } + return nil + }) + if err != nil { + return err + } } else { // Set session scope system variable. if sysVar.Scope&variable.ScopeSession == 0 { diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index ba428cebb59ee..439da9df369f8 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -76,6 +76,8 @@ func Optimize(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) ( return nil, errors.Trace(err) } + ctx.GetSessionVars().StmtCtx.Tables = builder.GetDBTableInfo() + // Maybe it's better to move this to Preprocess, but check privilege need table // information, which is collected into visitInfo during logical plan builder. if pm := privilege.GetPrivilegeManager(ctx); pm != nil { diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index f2a19e8eaf11b..fb43eb371e802 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/parser_driver" @@ -324,6 +325,26 @@ func (b *planBuilder) detectSelectAgg(sel *ast.SelectStmt) bool { return false } +// GetDBTableInfo gets the accessed dbs and tables info. +func (b *planBuilder) GetDBTableInfo() []stmtctx.TableEntry { + var tables []stmtctx.TableEntry + existsFunc := func(tbls []stmtctx.TableEntry, tbl *stmtctx.TableEntry) bool { + for _, t := range tbls { + if t == *tbl { + return true + } + } + return false + } + for _, v := range b.visitInfo { + tbl := &stmtctx.TableEntry{DB: v.db, Table: v.table} + if !existsFunc(tables, tbl) { + tables = append(tables, *tbl) + } + } + return tables +} + func getPathByIndexName(paths []*accessPath, idxName model.CIStr, tblInfo *model.TableInfo) *accessPath { var tablePath *accessPath for _, path := range paths { diff --git a/plugin/audit.go b/plugin/audit.go new file mode 100644 index 0000000000000..8ad556495ac62 --- /dev/null +++ b/plugin/audit.go @@ -0,0 +1,87 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "context" + + "github.com/pingcap/parser/auth" + "github.com/pingcap/tidb/sessionctx/variable" +) + +// GeneralEvent presents TiDB generate event. +type GeneralEvent byte + +const ( + // Log presents log event. + Log GeneralEvent = iota + // Error presents error event. + Error + // Result presents result event. + Result + // Status presents status event. + Status +) + +// ConnectionEvent presents TiDB connection event. +type ConnectionEvent byte + +const ( + // Connected presents new connection establish event(finish auth). + Connected ConnectionEvent = iota + // Disconnect presents disconnect event. + Disconnect + // ChangeUser presents change user. + ChangeUser + // PreAuth presents event before start auth. + PreAuth +) + +func (c ConnectionEvent) String() string { + switch c { + case Connected: + return "Connected" + case Disconnect: + return "Disconnect" + case ChangeUser: + return "ChangeUser" + case PreAuth: + return "PreAuth" + } + return "" +} + +// ParseEvent presents events happen around parser. +type ParseEvent byte + +const ( + // PreParse presents event before parse. + PreParse ParseEvent = 1 + iota + // PostParse presents event after parse. + PostParse +) + +// AuditManifest presents a sub-manifest that every audit plugin must provide. +type AuditManifest struct { + Manifest + // OnConnectionEvent will be called when TiDB receive or disconnect from client. + // return error will ignore and close current connection. + OnConnectionEvent func(ctx context.Context, identity *auth.UserIdentity, event ConnectionEvent, info *variable.ConnectionInfo) error + // OnGeneralEvent will be called during TiDB execution. + OnGeneralEvent func(ctx context.Context, sctx *variable.SessionVars, event GeneralEvent, cmd string) + // OnGlobalVariableEvent will be called when Change GlobalVariable. + OnGlobalVariableEvent func(ctx context.Context, sctx *variable.SessionVars, varName, varValue string) + // OnParseEvent will be called around parse logic. + OnParseEvent func(ctx context.Context, sctx *variable.SessionVars, event ParseEvent) error +} diff --git a/plugin/conn_ip_example/conn_ip_example.go b/plugin/conn_ip_example/conn_ip_example.go index 30cafec357fed..ebdc33ca5f114 100644 --- a/plugin/conn_ip_example/conn_ip_example.go +++ b/plugin/conn_ip_example/conn_ip_example.go @@ -40,8 +40,8 @@ func OnShutdown(ctx context.Context, manifest *plugin.Manifest) error { return nil } -// NotifyEvent implements TiDB Audit plugin's NotifyEvent SPI. -func NotifyEvent(ctx context.Context) error { +// OnGeneralEvent implements TiDB Audit plugin's OnGeneralEvent SPI. +func OnGeneralEvent(ctx context.Context, sctx *variable.SessionVars, event plugin.GeneralEvent, cmd byte, stmt string) error { fmt.Println("conn_ip_example notifiy called") fmt.Println("variable test: ", variable.GetSysVar("conn_ip_example_test_variable").Value) fmt.Printf("new connection by %s\n", ctx.Value("ip")) diff --git a/plugin/conn_ip_example/conn_ip_example_test.go b/plugin/conn_ip_example/conn_ip_example_test.go index 0d2bffe4f795a..76585eb80ae38 100644 --- a/plugin/conn_ip_example/conn_ip_example_test.go +++ b/plugin/conn_ip_example/conn_ip_example_test.go @@ -30,17 +30,23 @@ func Example_LoadRunShutdownPlugin() { PluginVarNames: &pluginVarNames, } - err := plugin.Init(ctx, cfg) + err := plugin.Load(ctx, cfg) if err != nil { panic(err) } - ps := plugin.GetByKind(plugin.Audit) - for _, auditPlugin := range ps { - if auditPlugin.State != plugin.Ready { - continue - } - plugin.DeclareAuditManifest(auditPlugin.Manifest).NotifyEvent(context.Background(), nil) + // load and start TiDB domain. + err = plugin.Init(ctx, cfg) + if err != nil { + panic(err) + } + + err = plugin.ForeachPlugin(plugin.Audit, func(auditPlugin *plugin.Plugin) error { + plugin.DeclareAuditManifest(auditPlugin.Manifest).OnGeneralEvent(context.Background(), nil, plugin.Log, "QUERY") + return nil + }) + if err != nil { + panic(err) } plugin.Shutdown(context.Background()) diff --git a/plugin/conn_ip_example/manifest.toml b/plugin/conn_ip_example/manifest.toml index 8f1a2c74ba7f8..b57badaf689f0 100644 --- a/plugin/conn_ip_example/manifest.toml +++ b/plugin/conn_ip_example/manifest.toml @@ -11,5 +11,5 @@ validate = "Validate" onInit = "OnInit" onShutdown = "OnShutdown" export = [ - {extPoint="NotifyEvent", impl="NotifyEvent"} + {extPoint="OnGeneralEvent", impl="OnGeneralEvent"} ] diff --git a/plugin/plugin.go b/plugin/plugin.go index 2066ca61357f5..fe4c64cc7369a 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -27,7 +27,9 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util" - "github.com/prometheus/common/log" + "github.com/pingcap/tidb/util/logutil" + log "github.com/sirupsen/logrus" + "go.uber.org/zap" ) // pluginGlobal holds all global variables for plugin. @@ -148,9 +150,9 @@ func (p *Plugin) validate(ctx context.Context, tiPlugins *plugins, mode validate return nil } -// Init initializes the plugin and load plugin by config param. -// This method isn't thread-safe and must be called before any other plugin operation. -func Init(ctx context.Context, cfg Config) (err error) { +// Load load plugin by config param. +// This method need be called before domain init to inject global variable info during bootstrap. +func Load(ctx context.Context, cfg Config) (err error) { tiPlugins := &plugins{ plugins: make(map[Kind][]Plugin), versions: make(map[string]uint16), @@ -174,6 +176,7 @@ func Init(ctx context.Context, cfg Config) (err error) { _, dup := tiPlugins.versions[pName] if dup { if cfg.SkipWhenFail { + log.Warnf("duplicate load %s and ignored", pName) continue } err = errDuplicatePlugin.GenWithStackByArgs(pluginID) @@ -184,6 +187,7 @@ func Init(ctx context.Context, cfg Config) (err error) { plugin, err = loadOne(cfg.PluginDir, ID(pluginID)) if err != nil { if cfg.SkipWhenFail { + log.Warnf("load plugin %s failure and ignored, %v", pluginID, err) continue } return @@ -196,6 +200,7 @@ func Init(ctx context.Context, cfg Config) (err error) { for i := range tiPlugins.plugins[kind] { if err = tiPlugins.plugins[kind][i].validate(ctx, tiPlugins, initMode); err != nil { if cfg.SkipWhenFail { + log.Warnf("validate plugin %s fail and disable plugin, %v", tiPlugins.plugins[kind][i].Name, err) tiPlugins.plugins[kind][i].State = Disable err = nil continue @@ -227,30 +232,42 @@ func Init(ctx context.Context, cfg Config) (err error) { return } -// InitWatchLoops starts etcd watch loops for plugin that need watch. -func InitWatchLoops(etcdClient *clientv3.Client) { - if etcdClient == nil { - return - } +// Init initializes the loaded plugin by config param. +// This method must be called after `Load` but before any other plugin method call, so it call got TiDB domain info. +func Init(ctx context.Context, cfg Config) (err error) { tiPlugins := pluginGlobal.plugins() + if tiPlugins == nil { + return nil + } for kind := range tiPlugins.plugins { for i := range tiPlugins.plugins[kind] { - if tiPlugins.plugins[kind][i].OnFlush == nil { - continue + p := tiPlugins.plugins[kind][i] + if err = p.OnInit(ctx, p.Manifest); err != nil { + if cfg.SkipWhenFail { + log.Warnf("call Plugin %s OnInit failure, err: %v", p.Name, err) + tiPlugins.plugins[kind][i].State = Disable + err = nil + continue + } + return } - const pluginWatchPrefix = "/tidb/plugins/" - ctx, cancel := context.WithCancel(context.Background()) - watcher := &flushWatcher{ - ctx: ctx, - cancel: cancel, - path: pluginWatchPrefix + tiPlugins.plugins[kind][i].Name, - etcd: etcdClient, - manifest: tiPlugins.plugins[kind][i].Manifest, + if p.OnFlush != nil && cfg.EtcdClient != nil { + const pluginWatchPrefix = "/tidb/plugins/" + ctx, cancel := context.WithCancel(context.Background()) + watcher := &flushWatcher{ + ctx: ctx, + cancel: cancel, + path: pluginWatchPrefix + tiPlugins.plugins[kind][i].Name, + etcd: cfg.EtcdClient, + manifest: tiPlugins.plugins[kind][i].Manifest, + } + tiPlugins.plugins[kind][i].flushWatcher = watcher + go util.WithRecovery(watcher.watchLoop, nil) } - tiPlugins.plugins[kind][i].flushWatcher = watcher - go util.WithRecovery(watcher.watchLoop, nil) + tiPlugins.plugins[kind][i].State = Ready } } + return } type flushWatcher struct { @@ -325,6 +342,8 @@ func Shutdown(ctx context.Context) { p.flushWatcher.cancel() } if err := p.OnShutdown(ctx, p.Manifest); err != nil { + logutil.Logger(ctx).Error("call OnShutdown failure", + zap.String("pluginName", p.Name), zap.Error(err)) } } } @@ -348,13 +367,23 @@ func Get(kind Kind, name string) *Plugin { return nil } -// GetByKind finds and returns plugin by kind parameters. -func GetByKind(kind Kind) []Plugin { +// ForeachPlugin loops all ready plugins. +func ForeachPlugin(kind Kind, fn func(plugin *Plugin) error) error { plugins := pluginGlobal.plugins() if plugins == nil { return nil } - return plugins.plugins[kind] + for i := range plugins.plugins[kind] { + p := &plugins.plugins[kind][i] + if p.State != Ready { + continue + } + err := fn(p) + if err != nil { + return err + } + } + return nil } // GetAll finds and returns all plugins. @@ -369,8 +398,8 @@ func GetAll() map[Kind][]Plugin { // NotifyFlush notify plugins to do flush logic. func NotifyFlush(dom *domain.Domain, pluginName string) error { p := getByName(pluginName) - if p == nil || p.Manifest.flushWatcher == nil { - return errors.Errorf("plugin %s doesn't exists or unsupported flush", pluginName) + if p == nil || p.Manifest.flushWatcher == nil || p.State != Ready { + return errors.Errorf("plugin %s doesn't exists or unsupported flush or doesn't start with PD", pluginName) } _, err := dom.GetEtcdClient().KV.Put(context.Background(), p.Manifest.flushWatcher.path, "") if err != nil { diff --git a/plugin/spi.go b/plugin/spi.go index 2613b419e714d..8be7f6253f52e 100644 --- a/plugin/spi.go +++ b/plugin/spi.go @@ -39,11 +39,21 @@ type Manifest struct { License string BuildTime string SysVars map[string]*variable.SysVar - Validate func(ctx context.Context, manifest *Manifest) error - OnInit func(ctx context.Context, manifest *Manifest) error - OnShutdown func(ctx context.Context, manifest *Manifest) error - OnFlush func(ctx context.Context, manifest *Manifest) error - flushWatcher *flushWatcher + // Validate defines the validate logic for plugin. + // returns error will stop load plugin process and TiDB startup. + Validate func(ctx context.Context, manifest *Manifest) error + // OnInit defines the plugin init logic. + // it will be called after domain init. + // return error will stop load plugin process and TiDB startup. + OnInit func(ctx context.Context, manifest *Manifest) error + // OnShutDown defines the plugin cleanup logic. + // return error will write log and continue shutdown. + OnShutdown func(ctx context.Context, manifest *Manifest) error + // OnFlush defines flush logic after executed `flush tidb plugins`. + // it will be called after OnInit. + // return error will write log and continue watch following flush. + OnFlush func(ctx context.Context, manifest *Manifest) error + flushWatcher *flushWatcher } // ExportManifest exports a manifest to TiDB as a known format. @@ -53,12 +63,6 @@ func ExportManifest(m interface{}) *Manifest { return (*Manifest)(unsafe.Pointer(v.Pointer())) } -// AuditManifest presents a sub-manifest that every audit plugin must provide. -type AuditManifest struct { - Manifest - NotifyEvent func(ctx context.Context, sctx *variable.SessionVars) error -} - // AuthenticationManifest presents a sub-manifest that every audit plugin must provide. type AuthenticationManifest struct { Manifest diff --git a/plugin/spi_test.go b/plugin/spi_test.go index 98e676acfcc3a..efdd8b53802cb 100644 --- a/plugin/spi_test.go +++ b/plugin/spi_test.go @@ -36,15 +36,14 @@ func TestExportManifest(t *testing.T) { return nil }, }, - NotifyEvent: func(ctx context.Context, sctx *variable.SessionVars) error { + OnGeneralEvent: func(ctx context.Context, sctx *variable.SessionVars, event plugin.GeneralEvent, cmd string) { callRecorder.NotifyEventCalled = true - return nil }, } exported := plugin.ExportManifest(manifest) exported.OnInit(context.Background(), exported) audit := plugin.DeclareAuditManifest(exported) - audit.NotifyEvent(context.Background(), nil) + audit.OnGeneralEvent(context.Background(), nil, plugin.Log, "QUERY") if !callRecorder.NotifyEventCalled || !callRecorder.OnInitCalled { t.Fatalf("export test failure") } diff --git a/server/conn.go b/server/conn.go index 38f8169ab62e0..8bd173f0d6ce2 100644 --- a/server/conn.go +++ b/server/conn.go @@ -39,6 +39,7 @@ import ( "crypto/tls" "encoding/binary" "fmt" + "github.com/pingcap/tidb/plugin" "io" "net" "runtime" @@ -103,6 +104,9 @@ type clientConn struct { ctx QueryCtx // an interface to execute sql statements. attrs map[string]string // attributes parsed from client handshake response, not used for now. status int32 // dispatching/reading/shutdown/waitshutdown + peerHost string // peer host + peerPort string // peer port + lastCode uint16 // last error code // mu is used for cancelling the execution of current transaction. mu struct { @@ -402,14 +406,9 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { if err != nil { return errors.Trace(err) } - host := variable.DefHostname - if !cc.server.isUnixSocket() { - addr := cc.bufReadConn.RemoteAddr().String() - // Do Auth. - host, _, err = net.SplitHostPort(addr) - if err != nil { - return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, addr, "YES")) - } + host, err := cc.PeerHost() + if err != nil { + return err } if !cc.ctx.Auth(&auth.UserIdentity{Username: cc.user, Hostname: host}, authData, cc.salt) { return errors.Trace(errAccessDenied.GenWithStackByArgs(cc.user, host, "YES")) @@ -424,6 +423,27 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { return nil } +func (cc *clientConn) PeerHost() (host string, err error) { + if len(cc.peerHost) > 0 { + return cc.peerHost, nil + } + host = variable.DefHostname + if cc.server.isUnixSocket() { + cc.peerHost = host + return + } + addr := cc.bufReadConn.RemoteAddr().String() + var port string + host, port, err = net.SplitHostPort(addr) + if err != nil { + err = errAccessDenied.GenWithStackByArgs(cc.user, addr, "") + return + } + cc.peerHost = host + cc.peerPort = port + return +} + // Run reads client query and writes query result to client in for loop, if there is a panic during query handling, // it will be recovered and log the panic error. // This function returns and the connection is closed if there is an IO error or there is a panic. @@ -614,6 +634,10 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { span.Finish() }() + if cmd < mysql.ComEnd { + cc.ctx.SetCommandValue(cmd) + } + switch cmd { case mysql.ComSleep: // TODO: According to mysql document, this command is supposed to be used only internally. @@ -707,6 +731,7 @@ func (cc *clientConn) writeError(e error) error { m = mysql.NewErrf(mysql.ErrUnknown, "%s", e.Error()) } + cc.lastCode = m.Code data := cc.alloc.AllocWithLen(4, 16+len(m.Message)) data = append(data, mysql.ErrHeader) data = append(data, byte(m.Code), byte(m.Code>>8)) @@ -1158,5 +1183,21 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error { if err != nil { return errors.Trace(err) } + + err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + authPlugin := plugin.DeclareAuditManifest(p.Manifest) + if authPlugin.OnConnectionEvent != nil { + connInfo := cc.connectInfo() + err = authPlugin.OnConnectionEvent(context.Background(), &auth.UserIdentity{Hostname: connInfo.Host}, plugin.ChangeUser, connInfo) + if err != nil { + return errors.Trace(err) + } + } + return nil + }) + if err != nil { + return err + } + return cc.writeOK() } diff --git a/server/driver.go b/server/driver.go index df8bb2b00e36e..bc19a1eece7b0 100644 --- a/server/driver.go +++ b/server/driver.go @@ -86,6 +86,8 @@ type QueryCtx interface { // GetSessionVars return SessionVars. GetSessionVars() *variable.SessionVars + SetCommandValue(command byte) + SetSessionManager(util.SessionManager) } diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 20991846579e1..130869a46ead9 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -332,6 +332,11 @@ func (tc *TiDBContext) GetSessionVars() *variable.SessionVars { return tc.session.GetSessionVars() } +// SetCommandValue implements QueryCtx SetCommandValue method. +func (tc *TiDBContext) SetCommandValue(command byte) { + tc.session.SetCommandValue(command) +} + type tidbResultSet struct { recordSet sqlexec.RecordSet columns []*ColumnInfo diff --git a/server/server.go b/server/server.go index 8bf26eee20714..5ce401ce11433 100644 --- a/server/server.go +++ b/server/server.go @@ -39,26 +39,49 @@ import ( "net/http" // For pprof _ "net/http/pprof" + "os" + "os/user" "sync" "sync/atomic" "time" "github.com/blacktear23/go-proxyprotocol" "github.com/pingcap/errors" + "github.com/pingcap/parser/auth" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/plugin" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sys/linux" + log "github.com/sirupsen/logrus" "go.uber.org/zap" ) var ( baseConnID uint32 + serverPID int + osUser string + osVersion string ) +func init() { + serverPID = os.Getpid() + currentUser, err := user.Current() + if err != nil { + osUser = "" + } else { + osUser = currentUser.Name + } + osVersion, err = linux.OSVersion() + if err != nil { + osVersion = "" + } +} + var ( errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type") errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length") @@ -263,7 +286,29 @@ func (s *Server) Run() error { terror.Log(errors.Trace(err)) break } - go s.onConn(conn) + clientConn := s.newConn(conn) + err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + authPlugin := plugin.DeclareAuditManifest(p.Manifest) + if authPlugin.OnConnectionEvent != nil { + host, err := clientConn.PeerHost() + if err != nil { + log.Info(err) + terror.Log(clientConn.Close()) + return errors.Trace(err) + } + err = authPlugin.OnConnectionEvent(context.Background(), &auth.UserIdentity{Hostname: host}, plugin.PreAuth, nil) + if err != nil { + log.Info(err) + terror.Log(clientConn.Close()) + return errors.Trace(err) + } + } + return nil + }) + if err != nil { + continue + } + go s.onConn(clientConn) } err := s.listener.Close() terror.Log(errors.Trace(err)) @@ -303,18 +348,17 @@ func (s *Server) Close() { } // onConn runs in its own goroutine, handles queries from this connection. -func (s *Server) onConn(c net.Conn) { - conn := s.newConn(c) +func (s *Server) onConn(conn *clientConn) { ctx := logutil.WithConnID(context.Background(), conn.connectionID) if err := conn.handshake(ctx); err != nil { // Some keep alive services will send request to TiDB and disconnect immediately. // So we only record metrics. metrics.HandShakeErrorCounter.Inc() - err = c.Close() + err = conn.Close() terror.Log(errors.Trace(err)) return } - logutil.Logger(ctx).Info("new connection", zap.String("remoteAddr", c.RemoteAddr().String())) + logutil.Logger(ctx).Info("new connection", zap.String("remoteAddr", conn.bufReadConn.RemoteAddr().String())) defer func() { logutil.Logger(ctx).Info("close connection") }() @@ -324,7 +368,64 @@ func (s *Server) onConn(c net.Conn) { s.rwlock.Unlock() metrics.ConnGauge.Set(float64(connections)) + err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + authPlugin := plugin.DeclareAuditManifest(p.Manifest) + if authPlugin.OnConnectionEvent != nil { + connInfo := conn.connectInfo() + return authPlugin.OnConnectionEvent(context.Background(), conn.ctx.GetSessionVars().User, plugin.Connected, connInfo) + } + return nil + }) + if err != nil { + return + } + + connectedTime := time.Now() conn.Run(ctx) + + err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + authPlugin := plugin.DeclareAuditManifest(p.Manifest) + if authPlugin.OnConnectionEvent != nil { + connInfo := conn.connectInfo() + connInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond) + err := authPlugin.OnConnectionEvent(context.Background(), conn.ctx.GetSessionVars().User, plugin.Disconnect, connInfo) + if err != nil { + log.Warnf("call Plugin %s OnConnectionEvent(Disconnect) failure, err: %v", authPlugin.Name, err) + } + } + return nil + }) + if err != nil { + return + } +} + +func (cc *clientConn) connectInfo() *variable.ConnectionInfo { + connType := "Socket" + if cc.server.isUnixSocket() { + connType = "UnixSocket" + } else if cc.tlsConn != nil { + connType = "SSL/TLS" + } + connInfo := &variable.ConnectionInfo{ + ConnectionID: cc.connectionID, + ConnectionType: connType, + Host: cc.peerHost, + ClientIP: cc.peerHost, + ClientPort: cc.peerPort, + ServerID: 1, + ServerPort: int(cc.server.cfg.Port), + Duration: 0, + User: cc.user, + ServerOSLoginUser: osUser, + OSVersion: osVersion, + ClientVersion: "", + ServerVersion: mysql.TiDBReleaseVersion, + SSLVersion: "v1.2.0", // for current go version + PID: serverPID, + DB: cc.dbname, + } + return connInfo } // ShowProcessList implements the SessionManager interface. diff --git a/session/session.go b/session/session.go index c8ecd91abc569..fd5970602c553 100644 --- a/session/session.go +++ b/session/session.go @@ -92,6 +92,7 @@ type Session interface { PrepareTxnCtx(context.Context) // FieldList returns fields list of a table. FieldList(tableName string) (fields []*ast.ResultField, err error) + SetCommandValue(byte) } var ( @@ -274,6 +275,10 @@ func (s *session) FieldList(tableName string) ([]*ast.ResultField, error) { return fields, nil } +func (s *session) SetCommandValue(command byte) { + atomic.StoreUint32(&s.sessionVars.CommandValue, uint32(command)) +} + func (s *session) doCommit(ctx context.Context) error { if !s.txn.Valid() { return nil @@ -1206,7 +1211,7 @@ func loadSystemTZ(se *session) (string, error) { func BootstrapSession(store kv.Storage) (*domain.Domain, error) { cfg := config.GetGlobalConfig() if len(cfg.Plugin.Load) > 0 { - err := plugin.Init(context.Background(), plugin.Config{ + err := plugin.Load(context.Background(), plugin.Config{ Plugins: strings.Split(cfg.Plugin.Load, ","), PluginDir: cfg.Plugin.Dir, GlobalSysVar: &variable.SysVars, @@ -1246,6 +1251,13 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { } } + if len(cfg.Plugin.Load) > 0 { + err := plugin.Init(context.Background(), plugin.Config{EtcdClient: dom.GetEtcdClient()}) + if err != nil { + return nil, errors.Trace(err) + } + } + se1, err := createSession(store) if err != nil { return nil, errors.Trace(err) @@ -1255,10 +1267,6 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { return nil, errors.Trace(err) } - if len(cfg.Plugin.Load) > 0 { - plugin.InitWatchLoops(dom.GetEtcdClient()) - } - if raw, ok := store.(domain.EtcdBackend); ok { err = raw.StartGCWorker() if err != nil { diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 825dffc8d9df1..fd014bc0a9fec 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -92,6 +92,13 @@ type StatementContext struct { TableIDs []int64 IndexIDs []int64 StmtType string + Tables []TableEntry +} + +// TableEntry presents table in db. +type TableEntry struct { + DB string + Table string } // AddAffectedRows adds affected rows. diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 4253cea8c5cd9..1144cf4fea238 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -310,6 +310,29 @@ type SessionVars struct { // ConstraintCheckInPlace indicates whether to check the constraint when the SQL executing. ConstraintCheckInPlace bool + + // CommandValue indicates which command current session is doing. + CommandValue uint32 +} + +// ConnectionInfo present connection used by audit. +type ConnectionInfo struct { + ConnectionID uint32 + ConnectionType string + Host string + ClientIP string + ClientPort string + ServerID int + ServerPort int + Duration float64 + User string + ServerOSLoginUser string + OSVersion string + ClientVersion string + ServerVersion string + SSLVersion string + PID int + DB string } // NewSessionVars creates a session vars object. diff --git a/util/sys/linux/sys_linux.go b/util/sys/linux/sys_linux.go new file mode 100644 index 0000000000000..c53bf77b31681 --- /dev/null +++ b/util/sys/linux/sys_linux.go @@ -0,0 +1,40 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. +// +build linux + +package linux + +import "syscall" + +// OSVersion returns version info of operation system. +// e.g. Linux 4.15.0-45-generic.x86_64 +func OSVersion() (osVersion string, err error) { + var un syscall.Utsname + err = syscall.Uname(&un) + if err != nil { + return + } + charsToString := func(ca []int8) string { + s := make([]byte, len(ca)) + var lens int + for ; lens < len(ca); lens++ { + if ca[lens] == 0 { + break + } + s[lens] = uint8(ca[lens]) + } + return string(s[0:lens]) + } + osVersion = charsToString(un.Sysname[:]) + " " + charsToString(un.Release[:]) + "." + charsToString(un.Machine[:]) + return +} diff --git a/util/sys/linux/sys_other.go b/util/sys/linux/sys_other.go new file mode 100644 index 0000000000000..98b4eae9749f9 --- /dev/null +++ b/util/sys/linux/sys_other.go @@ -0,0 +1,24 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. +// +build !linux + +package linux + +import "runtime" + +// OSVersion returns version info of operation system. +// for non-linux system will only return os and arch info. +func OSVersion() (osVersion string, err error) { + osVersion = runtime.GOOS + "." + runtime.GOARCH + return +} diff --git a/util/sys/linux/sys_test.go b/util/sys/linux/sys_test.go new file mode 100644 index 0000000000000..b9ccc2a823bec --- /dev/null +++ b/util/sys/linux/sys_test.go @@ -0,0 +1,29 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. +package linux_test + +import ( + "testing" + + "github.com/pingcap/tidb/util/sys/linux" +) + +func TestGetOSVersion(t *testing.T) { + osRelease, err := linux.OSVersion() + if err != nil { + t.Fatal(t) + } + if len(osRelease) == 0 { + t.Fatalf("counld not get os version") + } +}