diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml new file mode 100644 index 00000000..8c3a92cb --- /dev/null +++ b/.github/workflows/pr-validation.yml @@ -0,0 +1,33 @@ +name: pr-validation + +on: + pull_request: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + go: ['1.16','1.17', '1.18'] + sqlImage: ['2017-latest','2019-latest'] + steps: + - uses: actions/checkout@v2 + - name: Setup go + uses: actions/setup-go@v2 + with: + go-version: '${{ matrix.go }}' + - name: Run tests against Linux SQL + run: | + go version + go get -d + export SQLCMDPASSWORD=$(date +%s|sha256sum|base64|head -c 32) + export SQLCMDUSER=sa + export SQLUSER=sa + export SQLPASSWORD=$SQLCMDPASSWORD + export DATABASE=test + docker run -m 2GB -e ACCEPT_EULA=1 -d --name sqlserver -p:1433:1433 -e SA_PASSWORD=$SQLCMDPASSWORD mcr.microsoft.com/mssql/server:${{ matrix.sqlImage }} + sleep 10 + sqlcmd -Q "CREATE DATABASE test" + go test -race -cpu 4 ./... diff --git a/mssql.go b/mssql.go index fbd44d1f..d9a52112 100644 --- a/mssql.go +++ b/mssql.go @@ -1074,7 +1074,6 @@ type Rowsq struct { stmt *Stmt cols []columnStruct reader *tokenProcessor - nextCols []columnStruct cancel func() requestDone bool inResultSet bool @@ -1102,8 +1101,11 @@ func (rc *Rowsq) Close() error { } } -// data/sql calls Columns during the app's call to Next +// ProcessSingleResponse queues MsgNext for every columns token. +// data/sql calls Columns during the app's call to Next. func (rc *Rowsq) Columns() (res []string) { + // r.cols is nil if the first query in a batch is a SELECT or similar query that returns a rowset. + // if will be non-nil for subsequent queries where NextResultSet() has populated it if rc.cols == nil { scan: for { @@ -1145,6 +1147,10 @@ func (rc *Rowsq) Next(dest []driver.Value) error { if tok == nil { return io.EOF } else { + switch tokdata := tok.(type) { + case doneInProcStruct: + tok = (doneStruct)(tokdata) + } switch tokdata := tok.(type) { case []interface{}: for i := range dest { @@ -1172,9 +1178,11 @@ func (rc *Rowsq) Next(dest []driver.Value) error { if rc.reader.outs.returnStatus != nil { *rc.reader.outs.returnStatus = tokdata } + case ServerError: + rc.requestDone = true + return tokdata } } - } else { return rc.stmt.c.checkBadConn(rc.reader.ctx, err, false) } @@ -1187,7 +1195,7 @@ func (rc *Rowsq) HasNextResultSet() bool { return !rc.requestDone } -// Scans to the next set of columns in the stream +// Scans to the end of the current statement being processed // Note that the caller may not have read all the rows in the prior set func (rc *Rowsq) NextResultSet() error { if rc.requestDone { @@ -1195,7 +1203,6 @@ func (rc *Rowsq) NextResultSet() error { } scan: for { - // we should have a columns token in the channel if we aren't at the end tok, err := rc.reader.nextToken() if rc.reader.sess.logFlags&logDebug != 0 { rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok))) @@ -1208,23 +1215,42 @@ scan: return io.EOF } switch tokdata := tok.(type) { + case doneInProcStruct: + tok = (doneStruct)(tokdata) + } + // ProcessSingleResponse queues a MsgNextResult for every "done" and "server error" token + // The only tokens to consume after a "done" should be "done", "server error", or "columns" + switch tokdata := tok.(type) { case []columnStruct: - rc.nextCols = tokdata + rc.cols = tokdata rc.inResultSet = true break scan case doneStruct: if tokdata.Status&doneMore == 0 { - rc.nextCols = nil rc.requestDone = true - break scan } + if tokdata.isError() { + e := rc.stmt.c.checkBadConn(rc.reader.ctx, tokdata.getError(), false) + switch e.(type) { + case Error: + // Ignore non-fatal server errors. Fatal errors are of type ServerError + default: + return e + } + } + rc.inResultSet = false + rc.cols = nil + break scan + case ReturnStatus: + if rc.reader.outs.returnStatus != nil { + *rc.reader.outs.returnStatus = tokdata + } + case ServerError: + rc.requestDone = true + return tokdata } } - rc.cols = rc.nextCols - rc.nextCols = nil - if rc.cols == nil { - return io.EOF - } + return nil } diff --git a/queries_go19_test.go b/queries_go19_test.go index 12371094..88affa41 100644 --- a/queries_go19_test.go +++ b/queries_go19_test.go @@ -1,3 +1,4 @@ +//go:build go1.9 // +build go1.9 package mssql @@ -1126,17 +1127,19 @@ func TestMessageQueue(t *testing.T) { msgs := []interface{}{ sqlexp.MsgNotice{Message: "msg1"}, + sqlexp.MsgNextResultSet{}, sqlexp.MsgNext{}, sqlexp.MsgRowsAffected{Count: 1}, + sqlexp.MsgNextResultSet{}, sqlexp.MsgNotice{Message: "msg2"}, sqlexp.MsgNextResultSet{}, + sqlexp.MsgNextResultSet{}, } i := 0 - rsCount := 0 for active { msg := retmsg.Message(ctx) if i >= len(msgs) { - t.Fatalf("Got extra message:%+v", msg) + t.Fatalf("Got extra message:%+v", reflect.TypeOf(msg)) } t.Log(reflect.TypeOf(msg)) if reflect.TypeOf(msgs[i]) != reflect.TypeOf(msg) { @@ -1147,10 +1150,6 @@ func TestMessageQueue(t *testing.T) { t.Log(m.Message) case sqlexp.MsgNextResultSet: active = rows.NextResultSet() - if active { - t.Fatal("NextResultSet returned true") - } - rsCount++ case sqlexp.MsgNext: if !rows.Next() { t.Fatal("rows.Next() returned false") @@ -1237,7 +1236,8 @@ select getdate() PRINT N'This is a message' select 199 RAISERROR (N'Testing!' , 11, 1) -select 300 +declare @d int = 300 +select @d ` func testMixedQuery(conn *sql.DB, b testing.TB) (msgs, errs, results, rowcounts int) { @@ -1368,3 +1368,135 @@ func TestCancelWithNoResults(t *testing.T) { t.Fatalf("Unexpected error: %v", r.Err()) } } + +const DropSprocWithCursor = `IF EXISTS (SELECT * FROM sys.objects WHERE object_id = OBJECT_ID(N'[dbo].[TestSqlCmd]') AND type in (N'P', N'PC')) +DROP PROCEDURE [dbo].[TestSqlCmd] +` + +// This query generates half a dozen tokenDoneInProc tokens which fill the channel if the app isn't scanning Rowsq +const CreateSprocWithCursor = ` +CREATE PROCEDURE [dbo].[TestSqlCmd] +AS +BEGIN + DECLARE @tmp int; + DECLARE Server_Cursor CURSOR FOR + SELECT 1 UNION SELECT 2 + OPEN Server_Cursor; + FETCH NEXT FROM Server_Cursor INTO @tmp; + WHILE @@FETCH_STATUS = 0 + BEGIN + PRINT @tmp + FETCH NEXT FROM Server_Cursor INTO @tmp; + END; + CLOSE Server_Cursor; + DEALLOCATE Server_Cursor; +END +` + +func TestSprocWithCursorNoResult(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + _, e := conn.Exec(DropSprocWithCursor) + if e != nil { + t.Fatalf("Unable to drop test sproc: %v", e) + } + _, e = conn.Exec(CreateSprocWithCursor) + if e != nil { + t.Fatalf("Unable to create test sproc: %v", e) + } + defer conn.Exec(DropSprocWithCursor) + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+500*time.Millisecond) + defer cancel() + retmsg := &sqlexp.ReturnMessage{} + // Use a sproc instead of the cursor loop directly to cover the different code path in token.go + r, err := conn.QueryContext(ctx, `exec [dbo].[TestSqlCmd]`, retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer r.Close() + active := true + rsCount := 0 + msgCount := 0 + for active { + msg := retmsg.Message(ctx) + t.Logf("Got a message: %v", reflect.TypeOf(msg)) + switch m := msg.(type) { + case sqlexp.MsgNext: + t.Fatalf("Got a MsgNext from a query with no rows") + case sqlexp.MsgError: + t.Fatalf("Got an error: %s", m.Error.Error()) + case sqlexp.MsgNotice: + msgCount++ + case sqlexp.MsgNextResultSet: + if active = r.NextResultSet(); active { + rsCount++ + } + } + } + if r.Err() != nil { + t.Fatalf("Got an error: %v", r.Err()) + } + if rsCount != 13 { + t.Fatalf("Unexpected record set count: %v", rsCount) + } + if msgCount != 2 { + t.Fatalf("Unexpected message count: %v", msgCount) + } +} + +func TestErrorAsLastResult(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond) + defer cancel() + retmsg := &sqlexp.ReturnMessage{} + // Use a sproc instead of the cursor loop directly to cover the different code path in token.go + r, err := conn.QueryContext(ctx, + ` + Print N'message' + select 1 + raiserror(N'Error!', 16, 1)`, + retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer r.Close() + active := true + d := 0 + err = nil + for active { + msg := retmsg.Message(ctx) + t.Logf("Got a message: %s", reflect.TypeOf(msg)) + switch m := msg.(type) { + case sqlexp.MsgNext: + if !r.Next() { + t.Fatalf("Next returned false") + } + r.Scan(&d) + if r.Next() { + t.Fatal("Second Next returned true") + } + case sqlexp.MsgError: + err = m.Error + case sqlexp.MsgNextResultSet: + active = r.NextResultSet() + } + } + if err == nil { + t.Fatal("Should have gotten an error message") + } else { + switch e := err.(type) { + case Error: + if e.Message != "Error!" || e.Class != 16 { + t.Fatalf("Got the wrong mssql error %v", e) + } + default: + t.Fatalf("Got an unexpected error %v", e) + } + } +} diff --git a/token.go b/token.go index 43039d3d..80e20c78 100644 --- a/token.go +++ b/token.go @@ -49,6 +49,20 @@ const ( doneSrvError = 0x100 ) +// CurCmd values in done (undocumented) +const ( + cmdSelect = 0xc1 + // cmdInsert = 0xc3 + // cmdDelete = 0xc4 + // cmdUpdate = 0xc5 + // cmdAbort = 0xd2 + // cmdBeginXaxt = 0xd4 + // cmdEndXact = 0xd5 + // cmdBulkInsert = 0xf0 + // cmdOpenCursor = 0x20 + // cmdMerge = 0x117 +) + // ENVCHANGE types // http://msdn.microsoft.com/en-us/library/dd303449.aspx const ( @@ -645,7 +659,6 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) { } func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs outputs) { - firstResult := true defer func() { if err := recover(); err != nil { if sess.logFlags&logErrors != 0 { @@ -655,7 +668,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS } close(ch) }() - + colsReceived := false packet_type, err := sess.buf.BeginRead() if err != nil { if sess.logFlags&logErrors != 0 { @@ -697,18 +710,26 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS done := parseDoneInProc(sess.buf) ch <- done - if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { - sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d rows affected)", done.RowCount)) + if done.Status&doneCount != 0 { + if sess.logFlags&logRows != 0 { + sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d rows affected)", done.RowCount)) + } - if outs.msgq != nil { + if (colsReceived || done.CurCmd != cmdSelect) && outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)}) } } + if outs.msgq != nil { + // For now we ignore ctx->Done errors that ReturnMessageEnqueue might return + // It's not clear how to handle them correctly here, and data/sql seems + // to set Rows.Err correctly when ctx expires already + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } + colsReceived = false if done.Status&doneMore == 0 { + // Rows marks the request as done when seeing this done token. We queue another result set message + // so the app calls NextResultSet again which will return false. if outs.msgq != nil { - // For now we ignore ctx->Done errors that ReturnMessageEnqueue might return - // It's not clear how to handle them correctly here, and data/sql seems - // to set Rows.Err correctly when ctx expires already _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } return @@ -729,16 +750,24 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS } return } - if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { - sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d row(s) affected)", done.RowCount)) - } ch <- done if done.Status&doneCount != 0 { - if outs.msgq != nil { + if sess.logFlags&logRows != 0 { + sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d row(s) affected)", done.RowCount)) + } + + if (colsReceived || done.CurCmd != cmdSelect) && outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)}) } + + } + colsReceived = false + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } if done.Status&doneMore == 0 { + // Rows marks the request as done when seeing this done token. We queue another result set message + // so the app calls NextResultSet again which will return false. if outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } @@ -747,14 +776,10 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS case tokenColMetadata: columns = parseColMetadata72(sess.buf) ch <- columns - + colsReceived = true if outs.msgq != nil { - if !firstResult { - _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) - } _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNext{}) } - firstResult = false case tokenRow: row := make([]interface{}, len(columns))