Skip to content

Commit

Permalink
server,mysql: support server-side cursors (#6648)
Browse files Browse the repository at this point in the history
Implement server-side cursors by handling COM_STMT_FETCH command.
The client indicates that it wants to use cursor
by setting a flag in COM_STMT_EXECUTE
Please refer to https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
Subsequently, the client acquires result rows repeatedly by COM_STMT_FETCH,
which will carry stmt-id and fetch size.
Please refer to https://dev.mysql.com/doc/internals/en/com-stmt-fetch.html
This commit only support forward-only, read-only cursor
  • Loading branch information
mz1999 authored and coocood committed Jun 22, 2018
1 parent 48719d2 commit 1fbcc10
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 25 deletions.
5 changes: 5 additions & 0 deletions mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ const (
ServerPSOutParams uint16 = 0x1000
)

// HasCursorExistsFlag return true if cursor exists indicated by server status.
func HasCursorExistsFlag(serverStatus uint16) bool {
return serverStatus&ServerStatusCursorExists > 0
}

// Identifier length limitations.
// See https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
const (
Expand Down
17 changes: 17 additions & 0 deletions mysql/const_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,20 @@ func (s *testMySQLConstSuite) TestNoBackslashEscapesMode(c *C) {
r = tk.MustQuery("SELECT '\\\\'")
r.Check(testkit.Rows("\\\\"))
}

func (s *testMySQLConstSuite) TestServerStatus(c *C) {
tests := []struct {
arg uint16
IsCursorExists bool
}{
{0, false},
{mysql.ServerStatusInTrans | mysql.ServerStatusNoBackslashEscaped, false},
{mysql.ServerStatusCursorExists, true},
{mysql.ServerStatusCursorExists | mysql.ServerStatusLastRowSend, true},
}

for _, t := range tests {
ret := mysql.HasCursorExistsFlag(t.arg)
c.Assert(ret, Equals, t.IsCursorExists)
}
}
109 changes: 90 additions & 19 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import (
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util/arena"
"github.com/pingcap/tidb/util/auth"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/memory"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -550,6 +551,8 @@ func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) {
label = "StmtPrepare"
case mysql.ComStmtExecute:
label = "StmtExecute"
case mysql.ComStmtFetch:
label = "StmtFetch"
case mysql.ComStmtClose:
label = "StmtClose"
case mysql.ComStmtSendLongData:
Expand Down Expand Up @@ -620,6 +623,8 @@ func (cc *clientConn) dispatch(data []byte) error {
return cc.handleStmtPrepare(hack.String(data))
case mysql.ComStmtExecute:
return cc.handleStmtExecute(ctx1, data)
case mysql.ComStmtFetch:
return cc.handleStmtFetch(ctx1, data)
case mysql.ComStmtClose:
return cc.handleStmtClose(data)
case mysql.ComStmtSendLongData:
Expand Down Expand Up @@ -700,19 +705,17 @@ func (cc *clientConn) writeError(e error) error {

// writeEOF writes an EOF packet.
// Note this function won't flush the stream because maybe there are more
// packets following it, the "more" argument would indicates that case.
// If "more" is true, a mysql.ServerMoreResultsExists bit would be set
// packets following it.
// serverStatus, a flag bit represents server information
// in the packet.
func (cc *clientConn) writeEOF(more bool) error {
func (cc *clientConn) writeEOF(serverStatus uint16) error {
data := cc.alloc.AllocWithLen(4, 9)

data = append(data, mysql.EOFHeader)
if cc.capability&mysql.ClientProtocol41 > 0 {
data = dumpUint16(data, cc.ctx.WarningCount())
status := cc.ctx.Status()
if more {
status |= mysql.ServerMoreResultsExists
}
status |= serverStatus
data = dumpUint16(data, status)
}

Expand Down Expand Up @@ -860,7 +863,7 @@ func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) {
}
if rs != nil {
if len(rs) == 1 {
err = cc.writeResultset(ctx, rs[0], false, false)
err = cc.writeResultset(ctx, rs[0], false, 0, 0)
} else {
err = cc.writeMultiResultset(ctx, rs, false)
}
Expand Down Expand Up @@ -902,19 +905,23 @@ func (cc *clientConn) handleFieldList(sql string) (err error) {
return errors.Trace(err)
}
}
if err := cc.writeEOF(false); err != nil {
if err := cc.writeEOF(0); err != nil {
return errors.Trace(err)
}
return errors.Trace(cc.flush())
}

// writeResultset writes data into a resultset and uses rs.Next to get row data back.
// If binary is true, the data would be encoded in BINARY format.
// If more is true, a flag bit would be set to indicate there are more
// serverStatus, a flag bit represents server information.
// fetchSize, the desired number of rows to be fetched each time when client uses cursor.
// resultsets, it's used to support the MULTI_RESULTS capability in mysql protocol.
func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary bool, more bool) (runErr error) {
func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16, fetchSize int) (runErr error) {
defer func() {
terror.Call(rs.Close)
// close ResultSet when cursor doesn't exist
if !mysql.HasCursorExistsFlag(serverStatus) {
terror.Call(rs.Close)
}
r := recover()
if r == nil {
return
Expand All @@ -929,14 +936,19 @@ func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary b
buf = buf[:stackSize]
log.Errorf("query: %s:\n%s", cc.lastCmd, buf)
}()
err := cc.writeChunks(ctx, rs, binary, more)
var err error
if mysql.HasCursorExistsFlag(serverStatus) {
err = cc.writeChunksWithFetchSize(ctx, rs, serverStatus, fetchSize)
} else {
err = cc.writeChunks(ctx, rs, binary, serverStatus)
}
if err != nil {
return errors.Trace(err)
}
return errors.Trace(cc.flush())
}

func (cc *clientConn) writeColumnInfo(columns []*ColumnInfo) error {
func (cc *clientConn) writeColumnInfo(columns []*ColumnInfo, serverStatus uint16) error {
data := make([]byte, 4, 1024)
data = dumpLengthEncodedInt(data, uint64(len(columns)))
if err := cc.writePacket(data); err != nil {
Expand All @@ -949,16 +961,16 @@ func (cc *clientConn) writeColumnInfo(columns []*ColumnInfo) error {
return errors.Trace(err)
}
}
if err := cc.writeEOF(false); err != nil {
if err := cc.writeEOF(serverStatus); err != nil {
return errors.Trace(err)
}
return nil
}

// writeChunks writes data from a Chunk, which filled data by a ResultSet, into a connection.
// binary specifies the way to dump data. It throws any error while dumping data.
// more will be passed into writeEOF.
func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool, more bool) error {
// serverStatus, a flag bit represents server information
func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16) error {
data := make([]byte, 4, 1024)
chk := rs.NewChunk()
gotColumnInfo := false
Expand All @@ -972,7 +984,7 @@ func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool
// We need to call Next before we get columns.
// Otherwise, we will get incorrect columns info.
columns := rs.Columns()
err = cc.writeColumnInfo(columns)
err = cc.writeColumnInfo(columns, serverStatus)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -997,12 +1009,71 @@ func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool
}
}
}
return errors.Trace(cc.writeEOF(more))
return errors.Trace(cc.writeEOF(serverStatus))
}

// writeChunksWithFetchSize writes data from a Chunk, which filled data by a ResultSet, into a connection.
// binary specifies the way to dump data. It throws any error while dumping data.
// serverStatus, a flag bit represents server information.
// fetchSize, the desired number of rows to be fetched each time when client uses cursor.
func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet, serverStatus uint16, fetchSize int) error {
fetchedRows := rs.GetFetchedRows()

// if fetchedRows is not enough, getting data from recordSet.
for len(fetchedRows) < fetchSize {
chk := rs.NewChunk()
// Here server.tidbResultSet implements Next method.
err := rs.Next(ctx, chk)
if err != nil {
return errors.Trace(err)
}
rowCount := chk.NumRows()
if rowCount == 0 {
break
}
// filling fetchedRows with chunk
for i := 0; i < rowCount; i++ {
fetchedRows = append(fetchedRows, chk.GetRow(i))
}
}

// tell the client COM_STMT_FETCH has finished by setting proper serverStatus,
// and close ResultSet.
if len(fetchedRows) == 0 {
serverStatus |= mysql.ServerStatusLastRowSend
terror.Call(rs.Close)
return errors.Trace(cc.writeEOF(serverStatus))
}

// construct the rows sent to the client according to fetchSize.
var curRows []chunk.Row
if fetchSize < len(fetchedRows) {
curRows = fetchedRows[:fetchSize]
fetchedRows = fetchedRows[fetchSize:]
} else {
curRows = fetchedRows[:]
fetchedRows = fetchedRows[:0]
}
rs.StoreFetchedRows(fetchedRows)

data := make([]byte, 4, 1024)
var err error
for _, row := range curRows {
data = data[0:4]
data, err = dumpBinaryRow(data, rs.Columns(), row)
if err != nil {
return errors.Trace(err)
}
if err = cc.writePacket(data); err != nil {
return errors.Trace(err)
}
}
return errors.Trace(cc.writeEOF(serverStatus))
}

func (cc *clientConn) writeMultiResultset(ctx context.Context, rss []ResultSet, binary bool) error {
for _, rs := range rss {
if err := cc.writeResultset(ctx, rs, binary, true); err != nil {
if err := cc.writeResultset(ctx, rs, binary, mysql.ServerMoreResultsExists, 0); err != nil {
return errors.Trace(err)
}
}
Expand Down
75 changes: 69 additions & 6 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (cc *clientConn) handleStmtPrepare(sql string) error {
}
}

if err := cc.writeEOF(false); err != nil {
if err := cc.writeEOF(0); err != nil {
return errors.Trace(err)
}
}
Expand All @@ -95,7 +95,7 @@ func (cc *clientConn) handleStmtPrepare(sql string) error {
}
}

if err := cc.writeEOF(false); err != nil {
if err := cc.writeEOF(0); err != nil {
return errors.Trace(err)
}

Expand All @@ -119,8 +119,20 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e

flag := data[pos]
pos++
// Now we only support CURSOR_TYPE_NO_CURSOR flag.
if flag != 0 {
// Please refer to https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
// The client indicates that it wants to use cursor by setting this flag.
// 0x00 CURSOR_TYPE_NO_CURSOR
// 0x01 CURSOR_TYPE_READ_ONLY
// 0x02 CURSOR_TYPE_FOR_UPDATE
// 0x04 CURSOR_TYPE_SCROLLABLE
// Now we only support forward-only, read-only cursor.
var useCursor bool
switch flag {
case 0:
useCursor = false
case 1:
useCursor = true
default:
return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag %d", flag)
}

Expand Down Expand Up @@ -172,7 +184,58 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
return errors.Trace(cc.writeOK())
}

return errors.Trace(cc.writeResultset(ctx, rs, true, false))
// if the client wants to use cursor
// we should hold the ResultSet in PreparedStatement for next stmt_fetch, and only send back ColumnInfo.
// Tell the client cursor exists in server by setting proper serverStatus.
if useCursor {
stmt.StoreResultSet(rs)
err = cc.writeColumnInfo(rs.Columns(), mysql.ServerStatusCursorExists)
if err != nil {
return errors.Trace(err)
}
// explicitly flush columnInfo to client.
return errors.Trace(cc.flush())
}
return errors.Trace(cc.writeResultset(ctx, rs, true, 0, 0))
}

// maxFetchSize constants
const (
maxFetchSize = 1024
)

func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) {

stmtID, fetchSize, err := parseStmtFetchCmd(data)
if err != nil {
return err
}

stmt := cc.ctx.GetStatement(int(stmtID))
if stmt == nil {
return mysql.NewErr(mysql.ErrUnknownStmtHandler,
strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch")
}
rs := stmt.GetResultSet()
if rs == nil {
return mysql.NewErr(mysql.ErrUnknownStmtHandler,
strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch_rs")
}

return errors.Trace(cc.writeResultset(ctx, rs, true, mysql.ServerStatusCursorExists, int(fetchSize)))
}

func parseStmtFetchCmd(data []byte) (uint32, uint32, error) {
if len(data) != 8 {
return 0, 0, mysql.ErrMalformPacket
}
// Please refer to https://dev.mysql.com/doc/internals/en/com-stmt-fetch.html
stmtID := binary.LittleEndian.Uint32(data[0:4])
fetchSize := binary.LittleEndian.Uint32(data[4:8])
if fetchSize > maxFetchSize {
fetchSize = maxFetchSize
}
return stmtID, fetchSize, nil
}

func parseStmtArgs(args []interface{}, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte) (err error) {
Expand Down Expand Up @@ -480,7 +543,7 @@ func (cc *clientConn) handleSetOption(data []byte) (err error) {
default:
return mysql.ErrMalformPacket
}
if err = cc.writeEOF(false); err != nil {
if err = cc.writeEOF(0); err != nil {
return errors.Trace(err)
}

Expand Down
23 changes: 23 additions & 0 deletions server/conn_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,26 @@ func (ts ConnTestSuite) TestParseStmtArgs(c *C) {
c.Assert(tt.args.args[0], Equals, tt.expect)
}
}

func (ts ConnTestSuite) TestParseStmtFetchCmd(c *C) {
tests := []struct {
arg []byte
stmtID uint32
fetchSize uint32
err error
}{
{[]byte{3, 0, 0, 0, 50, 0, 0, 0}, 3, 50, nil},
{[]byte{5, 0, 0, 0, 232, 3, 0, 0}, 5, 1000, nil},
{[]byte{5, 0, 0, 0, 0, 8, 0, 0}, 5, maxFetchSize, nil},
{[]byte{5, 0, 0}, 0, 0, mysql.ErrMalformPacket},
{[]byte{1, 0, 0, 0, 3, 2, 0, 0, 3, 5, 6}, 0, 0, mysql.ErrMalformPacket},
{[]byte{}, 0, 0, mysql.ErrMalformPacket},
}

for _, t := range tests {
stmtID, fetchSize, err := parseStmtFetchCmd([]byte(t.arg))
c.Assert(stmtID, Equals, t.stmtID)
c.Assert(fetchSize, Equals, t.fetchSize)
c.Assert(err, Equals, t.err)
}
}
Loading

0 comments on commit 1fbcc10

Please sign in to comment.