diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 93d4b65391452..146eb0c96fa0c 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -661,16 +661,14 @@ func (h *BindHandle) CaptureBaselines() { func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) { oriVals := sctx.GetSessionVars().UsePlanBaselines sctx.GetSessionVars().UsePlanBaselines = false - recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql)) + rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql)) sctx.GetSessionVars().UsePlanBaselines = oriVals - if len(recordSets) > 0 { - defer terror.Log(recordSets[0].Close()) - } if err != nil { return "", err } - chk := recordSets[0].NewChunk() - err = recordSets[0].Next(context.TODO(), chk) + defer terror.Call(rs.Close) + chk := rs.NewChunk() + err = rs.Next(context.TODO(), chk) if err != nil { return "", err } @@ -873,23 +871,20 @@ func runSQL(ctx context.Context, sctx sessionctx.Context, sql string, resultChan resultChan <- fmt.Errorf("run sql panicked: %v", string(buf)) } }() - recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) + rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) if err != nil { - if len(recordSets) > 0 { - terror.Call(recordSets[0].Close) - } + terror.Call(rs.Close) resultChan <- err return } - recordSet := recordSets[0] - chk := recordSets[0].NewChunk() + chk := rs.NewChunk() for { - err = recordSet.Next(ctx, chk) + err = rs.Next(ctx, chk) if err != nil || chk.NumRows() == 0 { break } } - terror.Call(recordSets[0].Close) + terror.Call(rs.Close) resultChan <- err } diff --git a/session/bootstrap.go b/session/bootstrap.go index b85bb79e7de13..19834e428156b 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -1160,8 +1160,8 @@ func upgradeToVer51(s Session, ver int64) { mustExecute(s, "COMMIT") }() mustExecute(s, h.LockBindInfoSQL()) - var recordSets []sqlexec.RecordSet - recordSets, err = s.ExecuteInternal(context.Background(), + var rs sqlexec.RecordSet + rs, err = s.ExecuteInternal(context.Background(), `SELECT bind_sql, default_db, status, create_time, charset, collation, source FROM mysql.bind_info WHERE source != 'builtin' @@ -1169,15 +1169,13 @@ func upgradeToVer51(s Session, ver int64) { if err != nil { logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err)) } - if len(recordSets) > 0 { - defer terror.Call(recordSets[0].Close) - } - req := recordSets[0].NewChunk() + defer terror.Call(rs.Close) + req := rs.NewChunk() iter := chunk.NewIterator4Chunk(req) p := parser.New() now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3) for { - err = recordSets[0].Next(context.TODO(), req) + err = rs.Next(context.TODO(), req) if err != nil { logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err)) } diff --git a/session/session.go b/session/session.go index 0c76f0015c9da..d05e914e59e89 100644 --- a/session/session.go +++ b/session/session.go @@ -105,7 +105,7 @@ type Session interface { // Parse is deprecated, use ParseWithParams() instead. Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) // ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements. - ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error) + ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error) String() string // String is used to debug. CommitTxn(context.Context) error RollbackTxn(context.Context) @@ -860,37 +860,17 @@ func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) { ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) startTime := time.Now() - recordSets, err := se.ExecuteInternal(ctx, sql) - defer func() { - for _, rs := range recordSets { - closeErr := rs.Close() - if closeErr != nil && err == nil { - err = closeErr - } - } - }() - if err != nil { + rs, err := se.ExecuteInternal(ctx, sql) + if err != nil || rs == nil { return nil, nil, err } - - var ( - rows []chunk.Row - fields []*ast.ResultField - ) - // Execute all recordset, take out the first one as result. - for i, rs := range recordSets { - tmp, err := drainRecordSet(ctx, se, rs) - if err != nil { - return nil, nil, err - } - - if i == 0 { - rows = tmp - fields = rs.Fields() - } + defer terror.Call(rs.Close) + rows, err := drainRecordSet(ctx, se, rs) + if err != nil { + return nil, nil, err } metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds()) - return rows, fields, err + return rows, rs.Fields(), err } func createSessionFunc(store kv.Storage) pools.Factory { @@ -1123,7 +1103,7 @@ func (rs *execStmtResult) Close() error { return finishStmt(context.Background(), se, err, rs.sql) } -func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (recordSets []sqlexec.RecordSet, err error) { +func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (rs sqlexec.RecordSet, err error) { origin := s.sessionVars.InRestrictedSQL s.sessionVars.InRestrictedSQL = true defer func() { @@ -1142,15 +1122,11 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter return nil, err } - rs, err := s.ExecuteStmt(ctx, stmt) + rs, err = s.ExecuteStmt(ctx, stmt) if err != nil { s.sessionVars.StmtCtx.AppendError(err) } - if rs == nil { - return nil, err - } - - return []sqlexec.RecordSet{rs}, err + return rs, err } func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { diff --git a/store/tikv/gcworker/gc_worker.go b/store/tikv/gcworker/gc_worker.go index b606dec2b1fac..41c7427779db5 100644 --- a/store/tikv/gcworker/gc_worker.go +++ b/store/tikv/gcworker/gc_worker.go @@ -1733,14 +1733,12 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { se := createSession(w.store) defer se.Close() rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key) - if len(rs) > 0 { - defer terror.Call(rs[0].Close) - } if err != nil { return "", errors.Trace(err) } - req := rs[0].NewChunk() - err = rs[0].Next(ctx, req) + defer terror.Call(rs.Close) + req := rs.NewChunk() + err = rs.Next(ctx, req) if err != nil { return "", errors.Trace(err) } diff --git a/util/mock/context.go b/util/mock/context.go index 13c878c822bed..f049a05beb842 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -62,8 +62,8 @@ func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet, } // ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface. -func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]sqlexec.RecordSet, error) { - return nil, errors.Errorf("Not Support.") +func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) { + return nil, errors.Errorf("Not Supported.") } type mockDDLOwnerChecker struct{} diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 92d6958a38667..cc23bda405b9f 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -88,7 +88,7 @@ func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias { type SQLExecutor interface { Execute(ctx context.Context, sql string) ([]RecordSet, error) // ExecuteInternal means execute sql as the internal sql. - ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]RecordSet, error) + ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (RecordSet, error) } // SQLParser is an interface provides parsing sql statement.