Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: refactor ExecuteInternal to return single resultset (#22546) #22640

Merged
merged 8 commits into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,16 +661,14 @@ func (h *BindHandle) CaptureBaselines() {
func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) {
oriVals := sctx.GetSessionVars().UsePlanBaselines
sctx.GetSessionVars().UsePlanBaselines = false
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
sctx.GetSessionVars().UsePlanBaselines = oriVals
if len(recordSets) > 0 {
defer terror.Log(recordSets[0].Close())
}
if err != nil {
return "", err
}
chk := recordSets[0].NewChunk()
err = recordSets[0].Next(context.TODO(), chk)
defer terror.Call(rs.Close)
chk := rs.NewChunk()
err = rs.Next(context.TODO(), chk)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -873,23 +871,20 @@ func runSQL(ctx context.Context, sctx sessionctx.Context, sql string, resultChan
resultChan <- fmt.Errorf("run sql panicked: %v", string(buf))
}
}()
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
if err != nil {
if len(recordSets) > 0 {
terror.Call(recordSets[0].Close)
}
terror.Call(rs.Close)
resultChan <- err
return
}
recordSet := recordSets[0]
chk := recordSets[0].NewChunk()
chk := rs.NewChunk()
for {
err = recordSet.Next(ctx, chk)
err = rs.Next(ctx, chk)
if err != nil || chk.NumRows() == 0 {
break
}
}
terror.Call(recordSets[0].Close)
terror.Call(rs.Close)
resultChan <- err
}

Expand Down
12 changes: 5 additions & 7 deletions session/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -1160,24 +1160,22 @@ func upgradeToVer51(s Session, ver int64) {
mustExecute(s, "COMMIT")
}()
mustExecute(s, h.LockBindInfoSQL())
var recordSets []sqlexec.RecordSet
recordSets, err = s.ExecuteInternal(context.Background(),
var rs sqlexec.RecordSet
rs, err = s.ExecuteInternal(context.Background(),
`SELECT bind_sql, default_db, status, create_time, charset, collation, source
FROM mysql.bind_info
WHERE source != 'builtin'
ORDER BY update_time DESC`)
if err != nil {
logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err))
}
if len(recordSets) > 0 {
defer terror.Call(recordSets[0].Close)
}
req := recordSets[0].NewChunk()
defer terror.Call(rs.Close)
req := rs.NewChunk()
iter := chunk.NewIterator4Chunk(req)
p := parser.New()
now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3)
for {
err = recordSets[0].Next(context.TODO(), req)
err = rs.Next(context.TODO(), req)
if err != nil {
logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err))
}
Expand Down
46 changes: 11 additions & 35 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ type Session interface {
// Parse is deprecated, use ParseWithParams() instead.
Parse(ctx context.Context, sql string) ([]ast.StmtNode, error)
// ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements.
ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error)
ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error)
String() string // String is used to debug.
CommitTxn(context.Context) error
RollbackTxn(context.Context)
Expand Down Expand Up @@ -860,37 +860,17 @@ func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast
func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) {
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
startTime := time.Now()
recordSets, err := se.ExecuteInternal(ctx, sql)
defer func() {
for _, rs := range recordSets {
closeErr := rs.Close()
if closeErr != nil && err == nil {
err = closeErr
}
}
}()
if err != nil {
rs, err := se.ExecuteInternal(ctx, sql)
if err != nil || rs == nil {
return nil, nil, err
}

var (
rows []chunk.Row
fields []*ast.ResultField
)
// Execute all recordset, take out the first one as result.
for i, rs := range recordSets {
tmp, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}

if i == 0 {
rows = tmp
fields = rs.Fields()
}
defer terror.Call(rs.Close)
rows, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}
metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds())
return rows, fields, err
return rows, rs.Fields(), err
}

func createSessionFunc(store kv.Storage) pools.Factory {
Expand Down Expand Up @@ -1123,7 +1103,7 @@ func (rs *execStmtResult) Close() error {
return finishStmt(context.Background(), se, err, rs.sql)
}

func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (recordSets []sqlexec.RecordSet, err error) {
func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (rs sqlexec.RecordSet, err error) {
origin := s.sessionVars.InRestrictedSQL
s.sessionVars.InRestrictedSQL = true
defer func() {
Expand All @@ -1142,15 +1122,11 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter
return nil, err
}

rs, err := s.ExecuteStmt(ctx, stmt)
rs, err = s.ExecuteStmt(ctx, stmt)
if err != nil {
s.sessionVars.StmtCtx.AppendError(err)
}
if rs == nil {
return nil, err
}

return []sqlexec.RecordSet{rs}, err
return rs, err
}

func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) {
Expand Down
4 changes: 2 additions & 2 deletions util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet,
}

// ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface.
func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Support.")
func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Supported.")
}

type mockDDLOwnerChecker struct{}
Expand Down
2 changes: 1 addition & 1 deletion util/sqlexec/restricted_sql_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias {
type SQLExecutor interface {
Execute(ctx context.Context, sql string) ([]RecordSet, error)
// ExecuteInternal means execute sql as the internal sql.
ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]RecordSet, error)
ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (RecordSet, error)
}

// SQLParser is an interface provides parsing sql statement.
Expand Down