Skip to content

Commit

Permalink
*: refactor ExecuteInternal to return single resultset (#22546) (#22640)
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Mar 5, 2021
1 parent 38f9bdd commit 649d0e0
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 64 deletions.
23 changes: 9 additions & 14 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,16 +661,14 @@ func (h *BindHandle) CaptureBaselines() {
func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) {
oriVals := sctx.GetSessionVars().UsePlanBaselines
sctx.GetSessionVars().UsePlanBaselines = false
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
sctx.GetSessionVars().UsePlanBaselines = oriVals
if len(recordSets) > 0 {
defer terror.Log(recordSets[0].Close())
}
if err != nil {
return "", err
}
chk := recordSets[0].NewChunk()
err = recordSets[0].Next(context.TODO(), chk)
defer terror.Call(rs.Close)
chk := rs.NewChunk()
err = rs.Next(context.TODO(), chk)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -873,23 +871,20 @@ func runSQL(ctx context.Context, sctx sessionctx.Context, sql string, resultChan
resultChan <- fmt.Errorf("run sql panicked: %v", string(buf))
}
}()
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
if err != nil {
if len(recordSets) > 0 {
terror.Call(recordSets[0].Close)
}
terror.Call(rs.Close)
resultChan <- err
return
}
recordSet := recordSets[0]
chk := recordSets[0].NewChunk()
chk := rs.NewChunk()
for {
err = recordSet.Next(ctx, chk)
err = rs.Next(ctx, chk)
if err != nil || chk.NumRows() == 0 {
break
}
}
terror.Call(recordSets[0].Close)
terror.Call(rs.Close)
resultChan <- err
}

Expand Down
12 changes: 5 additions & 7 deletions session/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -1160,24 +1160,22 @@ func upgradeToVer51(s Session, ver int64) {
mustExecute(s, "COMMIT")
}()
mustExecute(s, h.LockBindInfoSQL())
var recordSets []sqlexec.RecordSet
recordSets, err = s.ExecuteInternal(context.Background(),
var rs sqlexec.RecordSet
rs, err = s.ExecuteInternal(context.Background(),
`SELECT bind_sql, default_db, status, create_time, charset, collation, source
FROM mysql.bind_info
WHERE source != 'builtin'
ORDER BY update_time DESC`)
if err != nil {
logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err))
}
if len(recordSets) > 0 {
defer terror.Call(recordSets[0].Close)
}
req := recordSets[0].NewChunk()
defer terror.Call(rs.Close)
req := rs.NewChunk()
iter := chunk.NewIterator4Chunk(req)
p := parser.New()
now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3)
for {
err = recordSets[0].Next(context.TODO(), req)
err = rs.Next(context.TODO(), req)
if err != nil {
logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err))
}
Expand Down
46 changes: 11 additions & 35 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ type Session interface {
// Parse is deprecated, use ParseWithParams() instead.
Parse(ctx context.Context, sql string) ([]ast.StmtNode, error)
// ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements.
ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error)
ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error)
String() string // String is used to debug.
CommitTxn(context.Context) error
RollbackTxn(context.Context)
Expand Down Expand Up @@ -860,37 +860,17 @@ func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast
func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) {
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
startTime := time.Now()
recordSets, err := se.ExecuteInternal(ctx, sql)
defer func() {
for _, rs := range recordSets {
closeErr := rs.Close()
if closeErr != nil && err == nil {
err = closeErr
}
}
}()
if err != nil {
rs, err := se.ExecuteInternal(ctx, sql)
if err != nil || rs == nil {
return nil, nil, err
}

var (
rows []chunk.Row
fields []*ast.ResultField
)
// Execute all recordset, take out the first one as result.
for i, rs := range recordSets {
tmp, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}

if i == 0 {
rows = tmp
fields = rs.Fields()
}
defer terror.Call(rs.Close)
rows, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}
metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds())
return rows, fields, err
return rows, rs.Fields(), err
}

func createSessionFunc(store kv.Storage) pools.Factory {
Expand Down Expand Up @@ -1123,7 +1103,7 @@ func (rs *execStmtResult) Close() error {
return finishStmt(context.Background(), se, err, rs.sql)
}

func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (recordSets []sqlexec.RecordSet, err error) {
func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (rs sqlexec.RecordSet, err error) {
origin := s.sessionVars.InRestrictedSQL
s.sessionVars.InRestrictedSQL = true
defer func() {
Expand All @@ -1142,15 +1122,11 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter
return nil, err
}

rs, err := s.ExecuteStmt(ctx, stmt)
rs, err = s.ExecuteStmt(ctx, stmt)
if err != nil {
s.sessionVars.StmtCtx.AppendError(err)
}
if rs == nil {
return nil, err
}

return []sqlexec.RecordSet{rs}, err
return rs, err
}

func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) {
Expand Down
8 changes: 3 additions & 5 deletions store/tikv/gcworker/gc_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1733,14 +1733,12 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) {
se := createSession(w.store)
defer se.Close()
rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key)
if len(rs) > 0 {
defer terror.Call(rs[0].Close)
}
if err != nil {
return "", errors.Trace(err)
}
req := rs[0].NewChunk()
err = rs[0].Next(ctx, req)
defer terror.Call(rs.Close)
req := rs.NewChunk()
err = rs.Next(ctx, req)
if err != nil {
return "", errors.Trace(err)
}
Expand Down
4 changes: 2 additions & 2 deletions util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet,
}

// ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface.
func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Support.")
func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Supported.")
}

type mockDDLOwnerChecker struct{}
Expand Down
2 changes: 1 addition & 1 deletion util/sqlexec/restricted_sql_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias {
type SQLExecutor interface {
Execute(ctx context.Context, sql string) ([]RecordSet, error)
// ExecuteInternal means execute sql as the internal sql.
ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]RecordSet, error)
ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (RecordSet, error)
}

// SQLParser is an interface provides parsing sql statement.
Expand Down

0 comments on commit 649d0e0

Please sign in to comment.