From 4d88edfee2da058167b2af0388c56d9aa8979dcc Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 5 Nov 2024 13:25:08 -0500 Subject: [PATCH] GODRIVER-2388 implementation --- internal/driverutil/operation.go | 1 + .../client_side_encryption_test.go | 2 +- internal/integration/crud_prose_test.go | 553 ++++++++++++++ internal/integration/csot_prose_test.go | 54 ++ .../unified/client_operation_execution.go | 329 ++++++++ internal/integration/unified/error.go | 65 +- internal/integration/unified/operation.go | 2 + .../integration/unified/schema_version.go | 2 +- mongo/bulk_write.go | 4 +- mongo/client.go | 102 ++- mongo/client_bulk_write.go | 706 ++++++++++++++++++ mongo/client_bulk_write_models.go | 359 +++++++++ mongo/client_bulk_write_test.go | 66 ++ mongo/errors.go | 50 ++ mongo/options/clientbulkwriteoptions.go | 136 ++++ mongo/results.go | 45 ++ x/mongo/driver/batch_cursor.go | 35 +- x/mongo/driver/batches.go | 126 +++- x/mongo/driver/batches_test.go | 186 ++--- x/mongo/driver/operation.go | 433 ++++++----- x/mongo/driver/operation/abort_transaction.go | 5 +- x/mongo/driver/operation/aggregate.go | 10 +- x/mongo/driver/operation/command.go | 10 +- .../driver/operation/commit_transaction.go | 5 +- x/mongo/driver/operation/count.go | 4 +- x/mongo/driver/operation/create.go | 2 +- x/mongo/driver/operation/create_indexes.go | 4 +- .../driver/operation/create_search_indexes.go | 4 +- x/mongo/driver/operation/delete.go | 4 +- x/mongo/driver/operation/distinct.go | 4 +- x/mongo/driver/operation/drop_collection.go | 4 +- x/mongo/driver/operation/drop_indexes.go | 4 +- x/mongo/driver/operation/drop_search_index.go | 4 +- x/mongo/driver/operation/end_sessions.go | 5 +- x/mongo/driver/operation/find.go | 9 +- x/mongo/driver/operation/find_and_modify.go | 4 +- x/mongo/driver/operation/hello.go | 8 +- x/mongo/driver/operation/insert.go | 4 +- x/mongo/driver/operation/list_collections.go | 9 +- .../{listDatabases.go => list_databases.go} | 4 +- 
x/mongo/driver/operation/list_indexes.go | 10 +- x/mongo/driver/operation/update.go | 4 +- .../driver/operation/update_search_index.go | 4 +- x/mongo/driver/operation_exhaust.go | 5 +- x/mongo/driver/session/client_session.go | 2 - x/mongo/driver/wiremessage/wiremessage.go | 34 + 46 files changed, 2982 insertions(+), 440 deletions(-) create mode 100644 mongo/client_bulk_write.go create mode 100644 mongo/client_bulk_write_models.go create mode 100644 mongo/client_bulk_write_test.go create mode 100644 mongo/options/clientbulkwriteoptions.go rename x/mongo/driver/operation/{listDatabases.go => list_databases.go} (98%) diff --git a/internal/driverutil/operation.go b/internal/driverutil/operation.go index 32704312ff..e37cba5903 100644 --- a/internal/driverutil/operation.go +++ b/internal/driverutil/operation.go @@ -28,4 +28,5 @@ const ( ListIndexesOp = "listIndexes" // ListIndexesOp is the name for listing indexes ListDatabasesOp = "listDatabases" // ListDatabasesOp is the name for listing databases UpdateOp = "update" // UpdateOp is the name for updating + BulkWriteOp = "bulkWrite" // BulkWriteOp is the name for client-level bulk write ) diff --git a/internal/integration/client_side_encryption_test.go b/internal/integration/client_side_encryption_test.go index 1111a71021..cd49e2bf00 100644 --- a/internal/integration/client_side_encryption_test.go +++ b/internal/integration/client_side_encryption_test.go @@ -398,7 +398,7 @@ func TestClientSideEncryptionCustomCrypt(t *testing.T) { "expected 0 calls to DecryptExplicit, got %v", cc.numDecryptExplicitCalls) assert.Equal(mt, cc.numCloseCalls, 0, "expected 0 calls to Close, got %v", cc.numCloseCalls) - assert.Equal(mt, cc.numBypassAutoEncryptionCalls, 2, + assert.Equal(mt, cc.numBypassAutoEncryptionCalls, 1, "expected 2 calls to BypassAutoEncryption, got %v", cc.numBypassAutoEncryptionCalls) }) } diff --git a/internal/integration/crud_prose_test.go b/internal/integration/crud_prose_test.go index e77f4b553c..ed7ef35cfe 100644 
--- a/internal/integration/crud_prose_test.go +++ b/internal/integration/crud_prose_test.go @@ -10,15 +10,21 @@ import ( "bytes" "context" "errors" + "os" + "strings" "testing" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/failpoint" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" ) func TestWriteErrorsWithLabels(t *testing.T) { @@ -409,3 +415,550 @@ func TestErrorsCodeNamePropagated(t *testing.T) { assert.Equal(mt, expectedCodeName, wce.Name, "expected code name %q, got %q", expectedCodeName, wce.Name) }) } + +func TestClientBulkWrite(t *testing.T) { + mtOpts := mtest.NewOptions().MinServerVersion("8.0").AtlasDataLake(false).ClientType(mtest.Pinned) + mt := mtest.New(t, mtOpts) + + mt.Run("bulkWrite batch splits a writeModels input with greater than maxWriteBatchSize operations", func(mt *mtest.T) { + var opsCnt []int + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + if e.CommandName == "bulkWrite" { + var c struct { + Ops []bson.D + } + err := bson.Unmarshal(e.Command, &c) + require.NoError(mt, err) + opsCnt = append(opsCnt, len(c.Ops)) + } + }, + } + mt.ResetClient(options.Client().SetMonitor(monitor)) + var hello struct { + MaxWriteBatchSize int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + models := &mongo.ClientWriteModels{} + numModels := hello.MaxWriteBatchSize + 1 + for i := 0; i < numModels; i++ { + models. 
+ AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }) + } + result, err := mt.Client.BulkWrite(context.Background(), models) + require.NoError(mt, err, "BulkWrite error: %v", err) + assert.Equal(mt, numModels, int(result.InsertedCount), "expected InsertedCount: %d, got %d", numModels, result.InsertedCount) + require.Len(mt, opsCnt, 2, "expected %d bulkWrite commands, got: %d", 2, len(opsCnt)) + assert.Equal(mt, numModels-1, opsCnt[0], "expected %d firstEvent.command.ops, got: %d", numModels-1, opsCnt[0]) + assert.Equal(mt, 1, opsCnt[1], "expected %d secondEvent.command.ops, got: %d", 1, opsCnt[1]) + }) + + mt.Run("bulkWrite batch splits when an ops payload exceeds maxMessageSizeBytes", func(mt *mtest.T) { + var opsCnt []int + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + if e.CommandName == "bulkWrite" { + var c struct { + Ops []bson.D + } + err := bson.Unmarshal(e.Command, &c) + require.NoError(mt, err) + opsCnt = append(opsCnt, len(c.Ops)) + } + }, + } + mt.ResetClient(options.Client().SetMonitor(monitor)) + var hello struct { + MaxBsonObjectSize int + MaxMessageSizeBytes int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + models := &mongo.ClientWriteModels{} + numModels := hello.MaxMessageSizeBytes/hello.MaxBsonObjectSize + 1 + for i := 0; i < numModels; i++ { + models. 
+ AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize-500)}}, + }) + } + result, err := mt.Client.BulkWrite(context.Background(), models) + require.NoError(mt, err, "BulkWrite error: %v", err) + assert.Equal(mt, numModels, int(result.InsertedCount), "expected InsertedCount: %d, got: %d", numModels, result.InsertedCount) + require.Len(mt, opsCnt, 2, "expected %d bulkWrite commands, got: %d", 2, len(opsCnt)) + assert.Equal(mt, numModels-1, opsCnt[0], "expected %d firstEvent.command.ops, got: %d", numModels-1, opsCnt[0]) + assert.Equal(mt, 1, opsCnt[1], "expected %d secondEvent.command.ops, got: %d", 1, opsCnt[1]) + }) + + mt.Run("bulkWrite collects WriteConcernErrors across batches", func(mt *mtest.T) { + var eventCnt int + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + if e.CommandName == "bulkWrite" { + eventCnt++ + } + }, + } + mt.ResetClient(options.Client().SetRetryWrites(false).SetMonitor(monitor)) + var hello struct { + MaxWriteBatchSize int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + + mt.SetFailPoint(failpoint.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: failpoint.Mode{ + Times: 2, + }, + Data: failpoint.Data{ + FailCommands: []string{"bulkWrite"}, + WriteConcernError: &failpoint.WriteConcernError{ + Code: 91, + Errmsg: "Replication is being shut down", + }, + }, + }) + + models := &mongo.ClientWriteModels{} + numModels := hello.MaxWriteBatchSize + 1 + for i := 0; i < numModels; i++ { + models. 
+ AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }) + } + _, err = mt.Client.BulkWrite(context.Background(), models) + require.Error(mt, err, "expected a BulkWrite error") + bwe, ok := err.(mongo.ClientBulkWriteException) + require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) + assert.Len(mt, bwe.WriteConcernErrors, 2, "expected %d writeConcernErrors, got: %d", 2, len(bwe.WriteConcernErrors)) + require.NotNil(mt, bwe.PartialResult) + assert.Equal(mt, numModels, int(bwe.PartialResult.InsertedCount), + "expected InsertedCount: %d, got: %d", numModels, bwe.PartialResult.InsertedCount) + require.Equal(mt, 2, eventCnt, "expected %d bulkWrite commands, got: %d", 2, eventCnt) + }) + + mt.Run("bulkWrite handles individual WriteErrors across batches", func(mt *mtest.T) { + var eventCnt int + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + if e.CommandName == "bulkWrite" { + eventCnt++ + } + }, + } + + mt.ResetClient(options.Client()) + var hello struct { + MaxWriteBatchSize int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + + coll := mt.CreateCollection(mtest.Collection{DB: "db", Name: "coll"}, false) + err = coll.Drop(context.Background()) + require.NoError(mt, err, "Drop error: %v", err) + _, err = coll.InsertOne(context.Background(), bson.D{{"_id", 1}}) + require.NoError(mt, err, "InsertOne error: %v", err) + + models := &mongo.ClientWriteModels{} + numModels := hello.MaxWriteBatchSize + 1 + for i := 0; i < numModels; i++ { + models. 
+ AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"_id", 1}}, + }) + } + + mt.Run("unordered", func(mt *mtest.T) { + eventCnt = 0 + mt.ResetClient(options.Client().SetMonitor(monitor)) + _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false)) + require.Error(mt, err, "expected a BulkWrite error") + bwe, ok := err.(mongo.ClientBulkWriteException) + require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) + assert.Len(mt, bwe.WriteErrors, numModels, "expected %d writeErrors, got %d", numModels, len(bwe.WriteErrors)) + require.Equal(mt, 2, eventCnt, "expected %d bulkWrite commands, got: %d", 2, eventCnt) + }) + mt.Run("ordered", func(mt *mtest.T) { + eventCnt = 0 + mt.ResetClient(options.Client().SetMonitor(monitor)) + _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(true)) + require.Error(mt, err, "expected a BulkWrite error") + bwe, ok := err.(mongo.ClientBulkWriteException) + require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) + assert.Len(mt, bwe.WriteErrors, 1, "expected %d writeErrors, got: %d", 1, len(bwe.WriteErrors)) + require.Equal(mt, 1, eventCnt, "expected %d bulkWrite commands, got: %d", 1, eventCnt) + }) + }) + + mt.Run("bulkWrite handles a cursor requiring a getMore", func(mt *mtest.T) { + var getMoreCalled int + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + if e.CommandName == "getMore" { + getMoreCalled++ + } + }, + } + mt.ResetClient(options.Client().SetMonitor(monitor)) + var hello struct { + MaxBsonObjectSize int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + + coll := mt.CreateCollection(mtest.Collection{DB: "db", Name: "coll"}, false) + err = coll.Drop(context.Background()) + require.NoError(mt, err, "Drop error: %v", err) + + 
upsert := true + models := (&mongo.ClientWriteModels{}). + AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }). + AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }) + result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetVerboseResults(true)) + require.NoError(mt, err, "BulkWrite error: %v", err) + assert.Equal(mt, int64(2), result.UpsertedCount, "expected InsertedCount: %d, got: %d", 2, result.UpsertedCount) + assert.Len(mt, result.UpdateResults, 2, "expected %d UpdateResults, got: %d", 2, len(result.UpdateResults)) + assert.Equal(mt, 1, getMoreCalled, "expected %d getMore call, got: %d", 1, getMoreCalled) + }) + + mt.RunOpts("bulkWrite handles a cursor requiring a getMore within a transaction", + mtest.NewOptions().MinServerVersion("8.0").AtlasDataLake(false).ClientType(mtest.Pinned). 
+ Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.ShardedReplicaSet), + func(mt *mtest.T) { + var getMoreCalled int + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + if e.CommandName == "getMore" { + getMoreCalled++ + } + }, + } + mt.ResetClient(options.Client().SetMonitor(monitor)) + var hello struct { + MaxBsonObjectSize int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + + coll := mt.CreateCollection(mtest.Collection{DB: "db", Name: "coll"}, false) + err = coll.Drop(context.Background()) + require.NoError(mt, err, "Drop error: %v", err) + + session, err := mt.Client.StartSession() + require.NoError(mt, err, "StartSession error: %v", err) + defer session.EndSession(context.Background()) + + upsert := true + models := (&mongo.ClientWriteModels{}). + AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }). 
+ AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }) + result, err := session.WithTransaction(context.Background(), func(ctx context.Context) (interface{}, error) { + return mt.Client.BulkWrite(ctx, models, options.ClientBulkWrite().SetVerboseResults(true)) + }) + require.NoError(mt, err, "BulkWrite error: %v", err) + cbwResult, ok := result.(*mongo.ClientBulkWriteResult) + require.True(mt, ok, "expected a ClientBulkWriteResult, got %T", result) + assert.Equal(mt, int64(2), cbwResult.UpsertedCount, "expected InsertedCount: %d, got: %d", 2, cbwResult.UpsertedCount) + assert.Len(mt, cbwResult.UpdateResults, 2, "expected %d UpdateResults, got: %d", 2, len(cbwResult.UpdateResults)) + assert.Equal(mt, 1, getMoreCalled, "expected %d getMore call, got: %d", 1, getMoreCalled) + }) + + mt.Run("bulkWrite handles a getMore error", func(mt *mtest.T) { + var getMoreCalled int + var killCursorsCalled int + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + switch e.CommandName { + case "getMore": + getMoreCalled++ + case "killCursors": + killCursorsCalled++ + } + }, + } + mt.ResetClient(options.Client().SetMonitor(monitor)) + var hello struct { + MaxBsonObjectSize int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + + mt.SetFailPoint(failpoint.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: failpoint.Mode{ + Times: 1, + }, + Data: failpoint.Data{ + FailCommands: []string{"getMore"}, + ErrorCode: 8, + }, + }) + + coll := mt.CreateCollection(mtest.Collection{DB: "db", Name: "coll"}, false) + err = coll.Drop(context.Background()) + require.NoError(mt, err, "Drop error: %v", err) + + upsert := true + models := (&mongo.ClientWriteModels{}). 
+ AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }). + AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }) + _, err = mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetVerboseResults(true)) + assert.Error(mt, err, "expected a BulkWrite error") + bwe, ok := err.(mongo.ClientBulkWriteException) + require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) + require.NotNil(mt, bwe.TopLevelError) + assert.Equal(mt, 8, bwe.TopLevelError.Code, "expected top level error code: %d, got; %d", 8, bwe.TopLevelError.Code) + require.NotNil(mt, bwe.PartialResult) + assert.Equal(mt, int64(2), bwe.PartialResult.UpsertedCount, "expected UpsertedCount: %d, got: %d", 2, bwe.PartialResult.UpsertedCount) + assert.Len(mt, bwe.PartialResult.UpdateResults, 1, "expected %d UpdateResults, got: %d", 1, len(bwe.PartialResult.UpdateResults)) + assert.Equal(mt, 1, getMoreCalled, "expected %d getMore call, got: %d", 1, getMoreCalled) + assert.Equal(mt, 1, killCursorsCalled, "expected %d killCursors call, got: %d", 1, killCursorsCalled) + }) + + mt.Run("bulkWrite returns error for unacknowledged too-large insert", func(mt *mtest.T) { + mt.ResetClient(options.Client()) + var hello struct { + MaxBsonObjectSize int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + mt.Run("insert", func(mt *mtest.T) { + models := (&mongo.ClientWriteModels{}). 
+ AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize)}}, + }) + _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) + require.EqualError(mt, err, driver.ErrDocumentTooLarge.Error()) + }) + mt.Run("replace", func(mt *mtest.T) { + models := (&mongo.ClientWriteModels{}). + AppendReplaceOne("db", "coll", &mongo.ClientReplaceOneModel{ + Filter: bson.D{}, + Replacement: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize)}}, + }) + _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) + require.EqualError(mt, err, driver.ErrDocumentTooLarge.Error()) + }) + }) + + mt.Run("bulkWrite batch splits when the addition of a new namespace exceeds the maximum message size", func(mt *mtest.T) { + type cmd struct { + Ops []bson.D + NsInfo []struct { + Ns string + } + } + var bwCmd []cmd + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + if e.CommandName == "bulkWrite" { + var c cmd + err := bson.Unmarshal(e.Command, &c) + require.NoError(mt, err, "Unmarshal error: %v", err) + bwCmd = append(bwCmd, c) + } + }, + } + mt.ResetClient(options.Client()) + var hello struct { + MaxBsonObjectSize int + MaxMessageSizeBytes int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + + newModels := func() (int, *mongo.ClientWriteModels) { + maxBsonObjectSize := hello.MaxBsonObjectSize + opsBytes := hello.MaxMessageSizeBytes - 1122 + numModels := opsBytes / maxBsonObjectSize + + models := &mongo.ClientWriteModels{} + n := numModels + for i := 0; i < n; i++ { + models. 
+ AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", maxBsonObjectSize-57)}}, + }) + } + if remainderBytes := opsBytes % maxBsonObjectSize; remainderBytes > 217 { + n++ + models. + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", remainderBytes-57)}}, + }) + } + return n, models + } + mt.Run("no batch-splitting required", func(mt *mtest.T) { + bwCmd = bwCmd[:0] + mt.ResetClient(options.Client().SetMonitor(monitor)) + + numModels, models := newModels() + models.AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }) + result, err := mt.Client.BulkWrite(context.Background(), models) + require.NoError(mt, err, "BulkWrite error: %v", err) + assert.Equal(mt, numModels+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", numModels+1, result.InsertedCount) + require.Len(mt, bwCmd, 1, "expected %d bulkWrite call, got: %d", 1, len(bwCmd)) + + assert.Len(mt, bwCmd[0].Ops, numModels+1, "expected %d ops, got: %d", numModels+1, len(bwCmd[0].Ops)) + require.Len(mt, bwCmd[0].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[0].NsInfo)) + assert.Equal(mt, "db.coll", bwCmd[0].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", bwCmd[0].NsInfo[0].Ns) + }) + mt.Run("batch-splitting required", func(mt *mtest.T) { + bwCmd = bwCmd[:0] + mt.ResetClient(options.Client().SetMonitor(monitor)) + + coll := strings.Repeat("c", 200) + numModels, models := newModels() + models.AppendInsertOne("db", coll, &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }) + result, err := mt.Client.BulkWrite(context.Background(), models) + require.NoError(mt, err, "BulkWrite error: %v", err) + assert.Equal(mt, numModels+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", numModels+1, result.InsertedCount) + require.Len(mt, bwCmd, 2, "expected %d bulkWrite calls, got: %d", 2, len(bwCmd)) + + assert.Len(mt, 
bwCmd[0].Ops, numModels, "expected %d ops, got: %d", numModels, len(bwCmd[0].Ops)) + require.Len(mt, bwCmd[0].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[0].NsInfo)) + assert.Equal(mt, "db.coll", bwCmd[0].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", bwCmd[0].NsInfo[0].Ns) + + assert.Len(mt, bwCmd[1].Ops, 1, "expected %d ops, got: %d", 1, len(bwCmd[1].Ops)) + require.Len(mt, bwCmd[1].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[1].NsInfo)) + assert.Equal(mt, "db."+coll, bwCmd[1].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db."+coll, bwCmd[1].NsInfo[0].Ns) + }) + }) + + mt.Run("bulkWrite returns an error if no operations can be added to ops", func(mt *mtest.T) { + mt.ResetClient(options.Client()) + var hello struct { + MaxMessageSizeBytes int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + mt.Run("document too large", func(mt *mtest.T) { + models := (&mongo.ClientWriteModels{}). + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", hello.MaxMessageSizeBytes)}}, + }) + _, err := mt.Client.BulkWrite(context.Background(), models) + require.EqualError(mt, err, driver.ErrDocumentTooLarge.Error()) + }) + mt.Run("namespace too large", func(mt *mtest.T) { + models := (&mongo.ClientWriteModels{}). + AppendInsertOne("db", strings.Repeat("c", hello.MaxMessageSizeBytes), &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }) + _, err := mt.Client.BulkWrite(context.Background(), models) + require.EqualError(mt, err, driver.ErrDocumentTooLarge.Error()) + }) + }) + + mt.Run("bulkWrite returns an error if auto-encryption is configured", func(mt *mtest.T) { + if os.Getenv("DOCKER_RUNNING") != "" { + mt.Skip("skipping test in docker environment") + } + + autoEncryptionOpts := options.AutoEncryption(). + SetKeyVaultNamespace("db.coll"). 
+ SetKmsProviders(map[string]map[string]interface{}{ + "aws": { + "accessKeyId": "foo", + "secretAccessKey": "bar", + }, + }) + mt.ResetClient(options.Client().SetAutoEncryptionOptions(autoEncryptionOpts)) + models := (&mongo.ClientWriteModels{}). + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }) + _, err := mt.Client.BulkWrite(context.Background(), models) + require.ErrorContains(mt, err, "bulkWrite does not currently support automatic encryption") + }) + + mt.Run("bulkWrite with unacknowledged write concern uses w:0 for all batches", func(mt *mtest.T) { + type cmd struct { + Ops []bson.D + WriteConcern struct { + W interface{} + } + } + var bwCmd []cmd + monitor := &event.CommandMonitor{ + Started: func(_ context.Context, e *event.CommandStartedEvent) { + if e.CommandName == "bulkWrite" { + var c cmd + err := bson.Unmarshal(e.Command, &c) + require.NoError(mt, err, "Unmarshal error: %v", err) + + bwCmd = append(bwCmd, c) + } + }, + } + mt.ResetClient(options.Client().SetMonitor(monitor)) + var hello struct { + MaxBsonObjectSize int + MaxMessageSizeBytes int + } + err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + + coll := mt.CreateCollection(mtest.Collection{DB: "db", Name: "coll"}, false) + err = coll.Drop(context.Background()) + require.NoError(mt, err, "Drop error: %v", err) + + numModels := hello.MaxMessageSizeBytes/hello.MaxBsonObjectSize + 1 + models := &mongo.ClientWriteModels{} + for i := 0; i < numModels; i++ { + models. 
+ AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize-500)}}, + }) + } + result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) + require.NoError(mt, err, "BulkWrite error: %v", err) + assert.Nil(mt, result, "expected a nil result, got: %v", result) + require.Len(mt, bwCmd, 2, "expected %d bulkWrite calls, got: %d", 2, len(bwCmd)) + + assert.Len(mt, bwCmd[0].Ops, numModels-1, "expected %d ops, got: %d", numModels-1, len(bwCmd[0].Ops)) + assert.Equal(mt, int32(0), bwCmd[0].WriteConcern.W, "expected writeConcern: %d, got: %v", 0, bwCmd[0].WriteConcern.W) + + assert.Len(mt, bwCmd[1].Ops, 1, "expected %d ops, got: %d", 1, len(bwCmd[1].Ops)) + assert.Equal(mt, int32(0), bwCmd[1].WriteConcern.W, "expected writeConcern: %d, got: %v", 0, bwCmd[1].WriteConcern.W) + + n, err := coll.CountDocuments(context.Background(), bson.D{}) + require.NoError(mt, err, "CountDocuments error: %v", err) + assert.Equal(mt, numModels, int(n), "expected %d documents, got: %d", numModels, n) + }) +} diff --git a/internal/integration/csot_prose_test.go b/internal/integration/csot_prose_test.go index 960891ee5c..ec944e9a9b 100644 --- a/internal/integration/csot_prose_test.go +++ b/internal/integration/csot_prose_test.go @@ -177,6 +177,60 @@ func TestCSOTProse(t *testing.T) { "expected ping to fail within 150ms") }) }) + + mt.RunOpts("11. multi-batch bulkWrites", mtest.NewOptions().MinServerVersion("8.0"). 
+ AtlasDataLake(false).Topologies(mtest.Single), func(mt *mtest.T) { + coll := mt.CreateCollection(mtest.Collection{DB: "db", Name: "coll"}, false) + err := coll.Drop(context.Background()) + require.NoError(mt, err, "Drop error: %v", err) + + mt.SetFailPoint(failpoint.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: failpoint.Mode{ + Times: 2, + }, + Data: failpoint.Data{ + FailCommands: []string{"bulkWrite"}, + BlockConnection: true, + BlockTimeMS: 1010, + }, + }) + + var hello struct { + MaxBsonObjectSize int + MaxMessageSizeBytes int + } + err = mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) + require.NoError(mt, err, "Hello error: %v", err) + + models := &mongo.ClientWriteModels{} + n := hello.MaxMessageSizeBytes/hello.MaxBsonObjectSize + 1 + for i := 0; i < n; i++ { + models. + AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize-500)}}, + }) + } + + var cnt int + cm := &event.CommandMonitor{ + Started: func(_ context.Context, evt *event.CommandStartedEvent) { + if evt.CommandName == "bulkWrite" { + cnt++ + } + }, + } + cliOptions := options.Client(). + SetTimeout(2 * time.Second). + SetMonitor(cm). 
+ ApplyURI(mtest.ClusterURI()) + integtest.AddTestServerAPIVersion(cliOptions) + cli, err := mongo.Connect(cliOptions) + require.NoError(mt, err, "Connect error: %v", err) + _, err = cli.BulkWrite(context.Background(), models) + assert.ErrorContains(mt, err, "context deadline exceeded", "expected a timeout error, got: %v", err) + assert.Equal(mt, 2, cnt, "expected bulkWrite calls: %d, got: %d", 2, cnt) + }) } func TestCSOTProse_GridFS(t *testing.T) { diff --git a/internal/integration/unified/client_operation_execution.go b/internal/integration/unified/client_operation_execution.go index 6758af9fc0..dc0c14b309 100644 --- a/internal/integration/unified/client_operation_execution.go +++ b/internal/integration/unified/client_operation_execution.go @@ -8,7 +8,10 @@ package unified import ( "context" + "errors" "fmt" + "strconv" + "strings" "time" "go.mongodb.org/mongo-driver/v2/bson" @@ -163,3 +166,329 @@ func executeListDatabases(ctx context.Context, operation *operation, nameOnly bo Build() return newDocumentResult(raw, nil), nil } + +func executeClientBulkWrite(ctx context.Context, operation *operation) (*operationResult, error) { + client, err := entities(ctx).client(operation.Object) + if err != nil { + return nil, err + } + + wirteModels := &mongo.ClientWriteModels{} + opts := options.ClientBulkWrite() + + elems, err := operation.Arguments.Elements() + if err != nil { + return nil, err + } + for _, elem := range elems { + key := elem.Key() + val := elem.Value() + + switch key { + case "models": + models, err := val.Array().Values() + if err != nil { + return nil, err + } + for _, m := range models { + model := m.Document().Index(0) + err = appendClientBulkWriteModel(model.Key(), model.Value().Document(), wirteModels) + if err != nil { + return nil, err + } + } + case "bypassDocumentValidation": + opts.SetBypassDocumentValidation(val.Boolean()) + case "comment": + opts.SetComment(val) + case "let": + opts.SetLet(val.Document()) + case "ordered": + 
opts.SetOrdered(val.Boolean()) + case "verboseResults": + opts.SetVerboseResults(val.Boolean()) + case "writeConcern": + var wc writeConcern + err := bson.Unmarshal(val.Value, &wc) + if err != nil { + return nil, err + } + c, err := wc.toWriteConcernOption() + if err != nil { + return nil, err + } + opts.SetWriteConcern(c) + default: + return nil, fmt.Errorf("unrecognized bulkWrite option %q", key) + } + } + + res, err := client.BulkWrite(ctx, wirteModels, opts) + if res == nil { + var bwe mongo.ClientBulkWriteException + if !errors.As(err, &bwe) || bwe.PartialResult == nil { + return newDocumentResult(emptyCoreDocument, err), nil + } + res = bwe.PartialResult + } + rawBuilder := bsoncore.NewDocumentBuilder(). + AppendInt64("deletedCount", res.DeletedCount). + AppendInt64("insertedCount", res.InsertedCount). + AppendInt64("matchedCount", res.MatchedCount). + AppendInt64("modifiedCount", res.ModifiedCount). + AppendInt64("upsertedCount", res.UpsertedCount) + + var resBuilder *bsoncore.DocumentBuilder + + resBuilder = bsoncore.NewDocumentBuilder() + for k, v := range res.DeleteResults { + resBuilder.AppendDocument(strconv.Itoa(k), + bsoncore.NewDocumentBuilder(). + AppendInt64("deletedCount", v.DeletedCount). + Build(), + ) + } + rawBuilder.AppendDocument("deleteResults", resBuilder.Build()) + + resBuilder = bsoncore.NewDocumentBuilder() + for k, v := range res.InsertResults { + t, d, err := bson.MarshalValue(v.InsertedID) + if err != nil { + return nil, err + } + resBuilder.AppendDocument(strconv.Itoa(k), + bsoncore.NewDocumentBuilder(). + AppendValue("insertedId", bsoncore.Value{Type: bsoncore.Type(t), Data: d}). + Build(), + ) + } + rawBuilder.AppendDocument("insertResults", resBuilder.Build()) + + resBuilder = bsoncore.NewDocumentBuilder() + for k, v := range res.UpdateResults { + b := bsoncore.NewDocumentBuilder(). + AppendInt64("matchedCount", v.MatchedCount). 
+ AppendInt64("modifiedCount", v.ModifiedCount) + if v.UpsertedID != nil { + t, d, err := bson.MarshalValue(v.UpsertedID) + if err != nil { + return nil, err + } + b.AppendValue("upsertedId", bsoncore.Value{Type: bsoncore.Type(t), Data: d}) + } + resBuilder.AppendDocument(strconv.Itoa(k), b.Build()) + } + rawBuilder.AppendDocument("updateResults", resBuilder.Build()) + + return newDocumentResult(rawBuilder.Build(), err), nil +} + +func appendClientBulkWriteModel(key string, value bson.Raw, model *mongo.ClientWriteModels) error { + switch key { + case "insertOne": + namespace, m, err := createClientInsertOneModel(value) + if err != nil { + return err + } + ns := strings.SplitN(namespace, ".", 2) + model.AppendInsertOne(ns[0], ns[1], m) + case "updateOne": + namespace, m, err := createClientUpdateOneModel(value) + if err != nil { + return err + } + ns := strings.SplitN(namespace, ".", 2) + model.AppendUpdateOne(ns[0], ns[1], m) + case "updateMany": + namespace, m, err := createClientUpdateManyModel(value) + if err != nil { + return err + } + ns := strings.SplitN(namespace, ".", 2) + model.AppendUpdateMany(ns[0], ns[1], m) + case "replaceOne": + namespace, m, err := createClientReplaceOneModel(value) + if err != nil { + return err + } + ns := strings.SplitN(namespace, ".", 2) + model.AppendReplaceOne(ns[0], ns[1], m) + case "deleteOne": + namespace, m, err := createClientDeleteOneModel(value) + if err != nil { + return err + } + ns := strings.SplitN(namespace, ".", 2) + model.AppendDeleteOne(ns[0], ns[1], m) + case "deleteMany": + namespace, m, err := createClientDeleteManyModel(value) + if err != nil { + return err + } + ns := strings.SplitN(namespace, ".", 2) + model.AppendDeleteMany(ns[0], ns[1], m) + } + return nil +} + +func createClientInsertOneModel(value bson.Raw) (string, *mongo.ClientInsertOneModel, error) { + var v struct { + Namespace string + Document bson.Raw + } + err := bson.Unmarshal(value, &v) + if err != nil { + return "", nil, err + } + return 
v.Namespace, &mongo.ClientInsertOneModel{ + Document: v.Document, + }, nil +} + +func createClientUpdateOneModel(value bson.Raw) (string, *mongo.ClientUpdateOneModel, error) { + var v struct { + Namespace string + Filter bson.Raw + Update interface{} + ArrayFilters []interface{} + Collation *options.Collation + Hint *bson.RawValue + Upsert *bool + } + err := bson.Unmarshal(value, &v) + if err != nil { + return "", nil, err + } + var hint interface{} + if v.Hint != nil { + hint, err = createHint(*v.Hint) + if err != nil { + return "", nil, err + } + } + model := &mongo.ClientUpdateOneModel{ + Filter: v.Filter, + Update: v.Update, + Collation: v.Collation, + Hint: hint, + Upsert: v.Upsert, + } + if len(v.ArrayFilters) > 0 { + model.ArrayFilters = v.ArrayFilters + } + return v.Namespace, model, nil + +} + +func createClientUpdateManyModel(value bson.Raw) (string, *mongo.ClientUpdateManyModel, error) { + var v struct { + Namespace string + Filter bson.Raw + Update interface{} + ArrayFilters []interface{} + Collation *options.Collation + Hint *bson.RawValue + Upsert *bool + } + err := bson.Unmarshal(value, &v) + if err != nil { + return "", nil, err + } + var hint interface{} + if v.Hint != nil { + hint, err = createHint(*v.Hint) + if err != nil { + return "", nil, err + } + } + model := &mongo.ClientUpdateManyModel{ + Filter: v.Filter, + Update: v.Update, + Collation: v.Collation, + Hint: hint, + Upsert: v.Upsert, + } + if len(v.ArrayFilters) > 0 { + model.ArrayFilters = v.ArrayFilters + } + return v.Namespace, model, nil +} + +func createClientReplaceOneModel(value bson.Raw) (string, *mongo.ClientReplaceOneModel, error) { + var v struct { + Namespace string + Filter bson.Raw + Replacement bson.Raw + Collation *options.Collation + Hint *bson.RawValue + Upsert *bool + } + err := bson.Unmarshal(value, &v) + if err != nil { + return "", nil, err + } + var hint interface{} + if v.Hint != nil { + hint, err = createHint(*v.Hint) + if err != nil { + return "", nil, err + } + 
} + return v.Namespace, &mongo.ClientReplaceOneModel{ + Filter: v.Filter, + Replacement: v.Replacement, + Collation: v.Collation, + Hint: hint, + Upsert: v.Upsert, + }, nil +} + +func createClientDeleteOneModel(value bson.Raw) (string, *mongo.ClientDeleteOneModel, error) { + var v struct { + Namespace string + Filter bson.Raw + Collation *options.Collation + Hint *bson.RawValue + } + err := bson.Unmarshal(value, &v) + if err != nil { + return "", nil, err + } + var hint interface{} + if v.Hint != nil { + hint, err = createHint(*v.Hint) + if err != nil { + return "", nil, err + } + } + return v.Namespace, &mongo.ClientDeleteOneModel{ + Filter: v.Filter, + Collation: v.Collation, + Hint: hint, + }, nil +} + +func createClientDeleteManyModel(value bson.Raw) (string, *mongo.ClientDeleteManyModel, error) { + var v struct { + Namespace string + Filter bson.Raw + Collation *options.Collation + Hint *bson.RawValue + } + err := bson.Unmarshal(value, &v) + if err != nil { + return "", nil, err + } + var hint interface{} + if v.Hint != nil { + hint, err = createHint(*v.Hint) + if err != nil { + return "", nil, err + } + } + return v.Namespace, &mongo.ClientDeleteManyModel{ + Filter: v.Filter, + Collation: v.Collation, + Hint: hint, + }, nil +} diff --git a/internal/integration/unified/error.go b/internal/integration/unified/error.go index e68689414f..f7adf5095d 100644 --- a/internal/integration/unified/error.go +++ b/internal/integration/unified/error.go @@ -8,6 +8,7 @@ package unified import ( "context" + "errors" "fmt" "strings" @@ -18,15 +19,22 @@ import ( // expectedError represents an error that is expected to occur during a test. This type ignores the "isError" field in // test files because it is always true if it is specified, so the runner can simply assert that an error occurred. 
type expectedError struct { - IsClientError *bool `bson:"isClientError"` - IsTimeoutError *bool `bson:"isTimeoutError"` - ErrorSubstring *string `bson:"errorContains"` - Code *int32 `bson:"errorCode"` - CodeName *string `bson:"errorCodeName"` - IncludedLabels []string `bson:"errorLabelsContain"` - OmittedLabels []string `bson:"errorLabelsOmit"` - ExpectedResult *bson.RawValue `bson:"expectResult"` - ErrorResponse *bson.Raw `bson:"errorResponse"` + IsClientError *bool `bson:"isClientError"` + IsTimeoutError *bool `bson:"isTimeoutError"` + ErrorSubstring *string `bson:"errorContains"` + Code *int32 `bson:"errorCode"` + CodeName *string `bson:"errorCodeName"` + IncludedLabels []string `bson:"errorLabelsContain"` + OmittedLabels []string `bson:"errorLabelsOmit"` + ExpectedResult *bson.RawValue `bson:"expectResult"` + ErrorResponse *bson.Raw `bson:"errorResponse"` + WriteErrors map[int]clientBulkWriteException `bson:"writeErrors"` + WriteConcernErrors []clientBulkWriteException `bson:"writeConcernErrors"` +} + +type clientBulkWriteException struct { + Code *int `bson:"code"` + Message *string `bson:"message"` } // verifyOperationError compares the expected error to the actual operation result. 
If the expected parameter is nil, @@ -133,6 +141,40 @@ func verifyOperationError(ctx context.Context, expected *expectedError, result *
 			return fmt.Errorf("error response comparison error: %w", err)
 		}
 	}
+	if expected.WriteErrors != nil {
+		var exception mongo.ClientBulkWriteException
+		if !errors.As(result.Err, &exception) {
+			return fmt.Errorf("expected a ClientBulkWriteException, got %T: %v", result.Err, result.Err)
+		}
+		if len(expected.WriteErrors) != len(exception.WriteErrors) {
+			return fmt.Errorf("expected errors: %v, got: %v", expected.WriteErrors, exception.WriteErrors)
+		}
+		for k, e := range expected.WriteErrors {
+			if e.Code != nil && *e.Code != exception.WriteErrors[k].Code {
+				return fmt.Errorf("expected errors: %v, got: %v", expected.WriteErrors, exception.WriteErrors)
+			}
+			if e.Message != nil && *e.Message != exception.WriteErrors[k].Message {
+				return fmt.Errorf("expected errors: %v, got: %v", expected.WriteErrors, exception.WriteErrors)
+			}
+		}
+	}
+	if expected.WriteConcernErrors != nil {
+		var exception mongo.ClientBulkWriteException
+		if !errors.As(result.Err, &exception) {
+			return fmt.Errorf("expected a ClientBulkWriteException, got %T: %v", result.Err, result.Err)
+		}
+		if len(expected.WriteConcernErrors) != len(exception.WriteConcernErrors) {
+			return fmt.Errorf("expected errors: %v, got: %v", expected.WriteConcernErrors, exception.WriteConcernErrors)
+		}
+		for i, e := range expected.WriteConcernErrors {
+			if e.Code != nil && *e.Code != exception.WriteConcernErrors[i].Code {
+				return fmt.Errorf("expected errors: %v, got: %v", expected.WriteConcernErrors, exception.WriteConcernErrors)
+			}
+			if e.Message != nil && *e.Message != exception.WriteConcernErrors[i].Message {
+				return fmt.Errorf("expected errors: %v, got: %v", expected.WriteConcernErrors, exception.WriteConcernErrors)
+			}
+		}
+	}
 	return nil
 }
@@ -175,6 +217,11 @@ func extractErrorDetails(err error) (errorDetails, bool) {
 			details.raw = we.Raw
 		}
details.labels = converted.Labels + case mongo.ClientBulkWriteException: + if converted.TopLevelError != nil { + details.raw = converted.TopLevelError.Raw + details.codes = append(details.codes, int32(converted.TopLevelError.Code)) + } default: return errorDetails{}, false } diff --git a/internal/integration/unified/operation.go b/internal/integration/unified/operation.go index f74fc9361e..a604c83d60 100644 --- a/internal/integration/unified/operation.go +++ b/internal/integration/unified/operation.go @@ -135,6 +135,8 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat return executeListDatabases(ctx, op, false) case "listDatabaseNames": return executeListDatabases(ctx, op, true) + case "clientBulkWrite": + return executeClientBulkWrite(ctx, op) // Database operations case "createCollection": diff --git a/internal/integration/unified/schema_version.go b/internal/integration/unified/schema_version.go index 1d2f99fb96..4bff29d7c5 100644 --- a/internal/integration/unified/schema_version.go +++ b/internal/integration/unified/schema_version.go @@ -16,7 +16,7 @@ import ( var ( supportedSchemaVersions = map[int]string{ - 1: "1.17", + 1: "1.21", } ) diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index d042a89c37..5838f141ae 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -167,8 +167,7 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (operation.InsertResult, error) { docs := make([]bsoncore.Document, len(batch.models)) - var i int - for _, model := range batch.models { + for i, model := range batch.models { converted := model.(*InsertOneModel) doc, err := marshal(converted.Document, bw.collection.bsonOpts, bw.collection.registry) if err != nil { @@ -180,7 +179,6 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera } docs[i] = doc - i++ } op := operation.NewInsert(docs...). 
diff --git a/mongo/client.go b/mongo/client.go index f3ed5d9097..6bc3290b2a 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -75,14 +75,15 @@ type Client struct { logger *logger.Logger // in-use encryption fields - keyVaultClientFLE *Client - keyVaultCollFLE *Collection - mongocryptdFLE *mongocryptdClient - cryptFLE driver.Crypt - metadataClientFLE *Client - internalClientFLE *Client - encryptedFieldsMap map[string]interface{} - authenticator driver.Authenticator + isAutoEncryptionSet bool + keyVaultClientFLE *Client + keyVaultCollFLE *Collection + mongocryptdFLE *mongocryptdClient + cryptFLE driver.Crypt + metadataClientFLE *Client + internalClientFLE *Client + encryptedFieldsMap map[string]interface{} + authenticator driver.Authenticator } // Connect creates a new Client and then initializes it using the Connect method. @@ -196,6 +197,7 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) { } // AutoEncryptionOptions if args.AutoEncryptionOptions != nil { + client.isAutoEncryptionSet = true if err := client.configureAutoEncryption(args); err != nil { return nil, err } @@ -438,10 +440,6 @@ func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (* return nil, replaceErrors(err) } - // Writes are not retryable on standalones, so let operation determine whether to retry - sess.RetryWrite = false - sess.RetryRead = c.retryReads - return &Session{ clientSession: sess, client: c, @@ -894,6 +892,86 @@ func (c *Client) createBaseCursorOptions() driver.CursorOptions { } } +// BulkWrite performs a client-level bulk write operation. +func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels, + opts ...options.Lister[options.ClientBulkWriteOptions]) (*ClientBulkWriteResult, error) { + // TODO: Remove once DRIVERS-2888 is implemented. + if c.isAutoEncryptionSet { + return nil, errors.New("bulkWrite does not currently support automatic encryption") + } + bwo, err := mongoutil.NewOptions(opts...) 
+ if err != nil { + return nil, err + } + + if ctx == nil { + ctx = context.Background() + } + + sess := sessionFromContext(ctx) + if sess == nil && c.sessionPool != nil { + sess = session.NewImplicitClientSession(c.sessionPool, c.id) + defer sess.EndSession() + } + + err = c.validSession(sess) + if err != nil { + return nil, err + } + + transactionRunning := sess.TransactionRunning() + wc := c.writeConcern + if transactionRunning { + wc = nil + } + if bwo.WriteConcern != nil { + if transactionRunning { + return nil, errors.New("cannot set write concern after starting a transaction") + } + wc = bwo.WriteConcern + } + acknowledged := wc.Acknowledged() + if !acknowledged { + if bwo.Ordered == nil || *bwo.Ordered { + return nil, errors.New("cannot request unacknowledged write concern and ordered writes") + } + sess = nil + } + + writeSelector := &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.Write{}, + &serverselector.Latency{Latency: c.localThreshold}, + }, + } + selector := makePinnedSelector(sess, writeSelector) + + op := clientBulkWrite{ + models: models.models, + ordered: bwo.Ordered, + bypassDocumentValidation: bwo.BypassDocumentValidation, + comment: bwo.Comment, + let: bwo.Let, + session: sess, + client: c, + selector: selector, + writeConcern: wc, + } + if bwo.VerboseResults == nil || !(*bwo.VerboseResults) { + op.errorsOnly = true + } else if !acknowledged { + return nil, errors.New("cannot request unacknowledged write concern and verbose results") + } + if err = op.execute(ctx); err != nil { + return nil, replaceErrors(err) + } + var results *ClientBulkWriteResult + if acknowledged { + results = &op.result + } + return results, nil +} + // newLogger will use the LoggerOptions to create an internal logger and publish // messages using a LogSink. 
func newLogger(opts options.Lister[options.LoggerOptions]) (*logger.Logger, error) { diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go new file mode 100644 index 0000000000..13b3085a52 --- /dev/null +++ b/mongo/client_bulk_write.go @@ -0,0 +1,706 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mongo + +import ( + "bytes" + "context" + "errors" + "io" + "strconv" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/driverutil" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/session" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" +) + +const ( + database = "admin" +) + +type clientBulkWrite struct { + models []clientWriteModel + errorsOnly bool + ordered *bool + bypassDocumentValidation *bool + comment interface{} + let interface{} + session *session.Client + client *Client + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + + result ClientBulkWriteResult +} + +func (bw *clientBulkWrite) execute(ctx context.Context) error { + if len(bw.models) == 0 { + return ErrEmptySlice + } + for _, m := range bw.models { + if m.model == nil { + return ErrNilDocument + } + } + batches := &modelBatches{ + session: bw.session, + client: bw.client, + ordered: bw.ordered == nil || *bw.ordered, + models: bw.models, + result: &bw.result, + retryMode: driver.RetryOnce, + } + err := driver.Operation{ + CommandFn: bw.newCommand(), + ProcessResponseFn: batches.processResponse, + Client: bw.session, 
+ Clock: bw.client.clock, + RetryMode: &batches.retryMode, + Type: driver.Write, + Batches: batches, + CommandMonitor: bw.client.monitor, + Database: database, + Deployment: bw.client.deployment, + Selector: bw.selector, + WriteConcern: bw.writeConcern, + Crypt: bw.client.cryptFLE, + ServerAPI: bw.client.serverAPI, + Timeout: bw.client.timeout, + Logger: bw.client.logger, + Authenticator: bw.client.authenticator, + Name: driverutil.BulkWriteOp, + }.Execute(ctx) + var exception *ClientBulkWriteException + switch tt := err.(type) { + case CommandError: + exception = &ClientBulkWriteException{ + TopLevelError: &WriteError{ + Code: int(tt.Code), + Message: tt.Message, + Raw: tt.Raw, + }, + } + default: + if errors.Is(err, driver.ErrUnacknowledgedWrite) { + err = nil + } + } + if len(batches.writeConcernErrors) > 0 || len(batches.writeErrors) > 0 { + if exception == nil { + exception = new(ClientBulkWriteException) + } + exception.WriteConcernErrors = batches.writeConcernErrors + exception.WriteErrors = batches.writeErrors + } + if exception != nil { + var hasSuccess bool + if batches.ordered { + _, ok := batches.writeErrors[0] + hasSuccess = !ok + } else { + hasSuccess = len(batches.writeErrors) < len(bw.models) + } + if hasSuccess { + exception.PartialResult = batches.result + } + return *exception + } + return err +} + +func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) ([]byte, error) { + return func(dst []byte, desc description.SelectedServer) ([]byte, error) { + dst = bsoncore.AppendInt32Element(dst, "bulkWrite", 1) + + dst = bsoncore.AppendBooleanElement(dst, "errorsOnly", bw.errorsOnly) + if bw.bypassDocumentValidation != nil && (desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 4)) { + dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *bw.bypassDocumentValidation) + } + if bw.comment != nil { + comment, err := marshalValue(bw.comment, bw.client.bsonOpts, bw.client.registry) + if err != 
nil { + return nil, err + } + dst = bsoncore.AppendValueElement(dst, "comment", comment) + } + dst = bsoncore.AppendBooleanElement(dst, "ordered", bw.ordered == nil || *bw.ordered) + if bw.let != nil { + let, err := marshal(bw.let, bw.client.bsonOpts, bw.client.registry) + if err != nil { + return nil, err + } + dst = bsoncore.AppendDocumentElement(dst, "let", let) + } + return dst, nil + } +} + +type cursorInfo struct { + Ok bool + Idx int32 + Code *int32 + Errmsg *string + ErrInfo bson.Raw + N int32 + NModified *int32 + Upserted *struct { + ID interface{} `bson:"_id"` + } +} + +func (cur *cursorInfo) extractError() *WriteError { + if cur.Ok { + return nil + } + err := &WriteError{ + Index: int(cur.Idx), + Details: cur.ErrInfo, + } + if cur.Code != nil { + err.Code = int(*cur.Code) + } + if cur.Errmsg != nil { + err.Message = *cur.Errmsg + } + return err +} + +type modelBatches struct { + session *session.Client + client *Client + + ordered bool + models []clientWriteModel + + offset int + + retryMode driver.RetryMode // RetryNone by default + cursorHandlers []func(*cursorInfo, bson.Raw) bool + newIDMap map[int]interface{} + + result *ClientBulkWriteResult + writeConcernErrors []WriteConcernError + writeErrors map[int]WriteError +} + +func (mb *modelBatches) IsOrdered() *bool { + return &mb.ordered +} + +func (mb *modelBatches) AdvanceBatches(n int) { + mb.offset += n + if mb.offset > len(mb.models) { + mb.offset = len(mb.models) + } +} + +func (mb *modelBatches) Size() int { + if mb.offset > len(mb.models) { + return 0 + } + return len(mb.models) - mb.offset +} + +func (mb *modelBatches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { + fn := functionSet{ + appendStart: func(dst []byte, identifier string) (int32, []byte) { + var idx int32 + dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence) + idx, dst = bsoncore.ReserveLength(dst) + dst = append(dst, identifier...) 
+ dst = append(dst, 0x00) + return idx, dst + }, + appendDocument: func(dst []byte, _ string, doc []byte) []byte { + dst = append(dst, doc...) + return dst + }, + updateLength: func(dst []byte, idx, length int32) []byte { + dst = bsoncore.UpdateLength(dst, idx, length) + return dst + }, + } + return mb.appendBatches(fn, dst, maxCount, maxDocSize, totalSize) +} + +func (mb *modelBatches) AppendBatchArray(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { + fn := functionSet{ + appendStart: bsoncore.AppendArrayElementStart, + appendDocument: bsoncore.AppendDocumentElement, + updateLength: func(dst []byte, idx, _ int32) []byte { + dst, _ = bsoncore.AppendArrayEnd(dst, idx) + return dst + }, + } + return mb.appendBatches(fn, dst, maxCount, maxDocSize, totalSize) +} + +type functionSet struct { + appendStart func([]byte, string) (int32, []byte) + appendDocument func([]byte, string, []byte) []byte + updateLength func([]byte, int32, int32) []byte +} + +func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) { + if mb.Size() == 0 { + return 0, dst, io.EOF + } + + mb.cursorHandlers = mb.cursorHandlers[:0] + mb.newIDMap = make(map[int]interface{}) + + nsMap := make(map[string]int) + getNsIndex := func(namespace string) (int, bool) { + v, ok := nsMap[namespace] + if ok { + return v, ok + } + nsIdx := len(nsMap) + nsMap[namespace] = nsIdx + return nsIdx, ok + } + + canRetry := true + checkSize := true + + l := len(dst) + + opsIdx, dst := fn.appendStart(dst, "ops") + nsIdx, nsDst := fn.appendStart(nil, "nsInfo") + + totalSize -= 1000 + size := len(dst) + len(nsDst) + var n int + for i := mb.offset; i < len(mb.models); i++ { + if n == maxCount { + break + } + + ns := mb.models[i].namespace + nsIdx, exists := getNsIndex(ns) + + var doc bsoncore.Document + var err error + switch model := mb.models[i].model.(type) { + case *ClientInsertOneModel: + checkSize = false + mb.cursorHandlers = 
append(mb.cursorHandlers, mb.appendInsertResult) + var id interface{} + id, doc, err = (&clientInsertDoc{ + namespace: nsIdx, + document: model.Document, + sizeLimit: maxDocSize, + }).marshal(mb.client.bsonOpts, mb.client.registry) + if err != nil { + break + } + mb.newIDMap[i] = id + case *ClientUpdateOneModel: + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendUpdateResult) + doc, err = (&clientUpdateDoc{ + namespace: nsIdx, + filter: model.Filter, + update: model.Update, + hint: model.Hint, + arrayFilters: model.ArrayFilters, + collation: model.Collation, + upsert: model.Upsert, + multi: false, + checkDollarKey: true, + }).marshal(mb.client.bsonOpts, mb.client.registry) + case *ClientUpdateManyModel: + canRetry = false + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendUpdateResult) + doc, err = (&clientUpdateDoc{ + namespace: nsIdx, + filter: model.Filter, + update: model.Update, + hint: model.Hint, + arrayFilters: model.ArrayFilters, + collation: model.Collation, + upsert: model.Upsert, + multi: true, + checkDollarKey: true, + }).marshal(mb.client.bsonOpts, mb.client.registry) + case *ClientReplaceOneModel: + checkSize = false + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendUpdateResult) + doc, err = (&clientUpdateDoc{ + namespace: nsIdx, + filter: model.Filter, + update: model.Replacement, + hint: model.Hint, + arrayFilters: nil, + collation: model.Collation, + upsert: model.Upsert, + multi: false, + checkDollarKey: false, + sizeLimit: maxDocSize, + }).marshal(mb.client.bsonOpts, mb.client.registry) + case *ClientDeleteOneModel: + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendDeleteResult) + doc, err = (&clientDeleteDoc{ + namespace: nsIdx, + filter: model.Filter, + collation: model.Collation, + hint: model.Hint, + multi: false, + }).marshal(mb.client.bsonOpts, mb.client.registry) + case *ClientDeleteManyModel: + canRetry = false + mb.cursorHandlers = append(mb.cursorHandlers, mb.appendDeleteResult) + doc, err = 
(&clientDeleteDoc{ + namespace: nsIdx, + filter: model.Filter, + collation: model.Collation, + hint: model.Hint, + multi: true, + }).marshal(mb.client.bsonOpts, mb.client.registry) + default: + mb.cursorHandlers = append(mb.cursorHandlers, nil) + } + if err != nil { + return 0, nil, err + } + length := len(doc) + if maxDocSize > 0 && length > maxDocSize+16*1024 { + return 0, nil, driver.ErrDocumentTooLarge + } + if !exists { + length += len(ns) + } + size += length + if size >= totalSize { + break + } + + dst = fn.appendDocument(dst, strconv.Itoa(n), doc) + if !exists { + idx, doc := bsoncore.AppendDocumentStart(nil) + doc = bsoncore.AppendStringElement(doc, "ns", ns) + doc, _ = bsoncore.AppendDocumentEnd(doc, idx) + nsDst = fn.appendDocument(nsDst, strconv.Itoa(n), doc) + } + n++ + } + if n == 0 { + return 0, dst[:l], nil + } + + dst = fn.updateLength(dst, opsIdx, int32(len(dst[opsIdx:]))) + nsDst = fn.updateLength(nsDst, nsIdx, int32(len(nsDst[nsIdx:]))) + dst = append(dst, nsDst...) + if checkSize && maxDocSize > 0 && len(dst)-l > maxDocSize+16*1024 { + return 0, nil, driver.ErrDocumentTooLarge + } + + mb.retryMode = driver.RetryNone + if mb.client.retryWrites && canRetry { + mb.retryMode = driver.RetryOnce + } + return n, dst, nil +} + +func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Document, info driver.ResponseInfo) error { + var writeCmdErr driver.WriteCommandError + if errors.As(info.Error, &writeCmdErr) && writeCmdErr.WriteConcernError != nil { + wce := convertDriverWriteConcernError(writeCmdErr.WriteConcernError) + if wce != nil { + mb.writeConcernErrors = append(mb.writeConcernErrors, *wce) + } + } + if len(resp) == 0 { + return nil + } + var res struct { + Ok bool + Cursor bsoncore.Document + NDeleted int32 + NInserted int32 + NMatched int32 + NModified int32 + NUpserted int32 + NErrors int32 + Code int32 + Errmsg string + } + dec := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(resp))) + 
dec.SetRegistry(mb.client.registry) + err := dec.Decode(&res) + if err != nil { + return err + } + if !res.Ok { + return ClientBulkWriteException{ + TopLevelError: &WriteError{ + Code: int(res.Code), + Message: res.Errmsg, + Raw: bson.Raw(resp), + }, + WriteConcernErrors: mb.writeConcernErrors, + WriteErrors: mb.writeErrors, + PartialResult: mb.result, + } + } + + mb.result.DeletedCount += int64(res.NDeleted) + mb.result.InsertedCount += int64(res.NInserted) + mb.result.MatchedCount += int64(res.NMatched) + mb.result.ModifiedCount += int64(res.NModified) + mb.result.UpsertedCount += int64(res.NUpserted) + + var cursorRes driver.CursorResponse + cursorRes, err = driver.NewCursorResponse(res.Cursor, info) + if err != nil { + return err + } + var bCursor *driver.BatchCursor + bCursor, err = driver.NewBatchCursor(cursorRes, mb.session, mb.client.clock, + driver.CursorOptions{ + CommandMonitor: mb.client.monitor, + Crypt: mb.client.cryptFLE, + ServerAPI: mb.client.serverAPI, + MarshalValueEncoderFn: newEncoderFn(mb.client.bsonOpts, mb.client.registry), + }, + ) + if err != nil { + return err + } + var cursor *Cursor + cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry) + if err != nil { + return err + } + defer cursor.Close(ctx) + + ok := true + for cursor.Next(ctx) { + var cur cursorInfo + err = cursor.Decode(&cur) + if err != nil { + return err + } + if int(cur.Idx) >= len(mb.cursorHandlers) { + continue + } + ok = mb.cursorHandlers[int(cur.Idx)](&cur, cursor.Current) && ok + } + err = cursor.Err() + if err != nil { + return err + } + if mb.ordered && (writeCmdErr.WriteConcernError != nil || !ok || !res.Ok || res.NErrors > 0) { + return ClientBulkWriteException{ + WriteConcernErrors: mb.writeConcernErrors, + WriteErrors: mb.writeErrors, + PartialResult: mb.result, + } + } + return nil +} + +func (mb *modelBatches) appendDeleteResult(cur *cursorInfo, raw bson.Raw) bool { + idx := int(cur.Idx) + mb.offset + if err := cur.extractError(); err != nil { 
+ err.Raw = raw + if mb.writeErrors == nil { + mb.writeErrors = make(map[int]WriteError) + } + mb.writeErrors[idx] = *err + return false + } + + if mb.result.DeleteResults == nil { + mb.result.DeleteResults = make(map[int]ClientDeleteResult) + } + mb.result.DeleteResults[idx] = ClientDeleteResult{int64(cur.N)} + + return true +} + +func (mb *modelBatches) appendInsertResult(cur *cursorInfo, raw bson.Raw) bool { + idx := int(cur.Idx) + mb.offset + if err := cur.extractError(); err != nil { + err.Raw = raw + if mb.writeErrors == nil { + mb.writeErrors = make(map[int]WriteError) + } + mb.writeErrors[idx] = *err + return false + } + + if mb.result.InsertResults == nil { + mb.result.InsertResults = make(map[int]ClientInsertResult) + } + mb.result.InsertResults[idx] = ClientInsertResult{mb.newIDMap[idx]} + + return true +} + +func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool { + idx := int(cur.Idx) + mb.offset + if err := cur.extractError(); err != nil { + err.Raw = raw + if mb.writeErrors == nil { + mb.writeErrors = make(map[int]WriteError) + } + mb.writeErrors[idx] = *err + return false + } + + if mb.result.UpdateResults == nil { + mb.result.UpdateResults = make(map[int]ClientUpdateResult) + } + result := ClientUpdateResult{ + MatchedCount: int64(cur.N), + } + if cur.NModified != nil { + result.ModifiedCount = int64(*cur.NModified) + } + if cur.Upserted != nil { + result.UpsertedID = cur.Upserted.ID + } + mb.result.UpdateResults[idx] = result + + return true +} + +type clientInsertDoc struct { + namespace int + document interface{} + + sizeLimit int +} + +func (d *clientInsertDoc) marshal(bsonOpts *options.BSONOptions, registry *bson.Registry) (interface{}, bsoncore.Document, error) { + uidx, doc := bsoncore.AppendDocumentStart(nil) + + doc = bsoncore.AppendInt32Element(doc, "insert", int32(d.namespace)) + f, err := marshal(d.document, bsonOpts, registry) + if err != nil { + return nil, nil, err + } + if d.sizeLimit > 0 && len(f) > 
d.sizeLimit { + return nil, nil, driver.ErrDocumentTooLarge + } + var id interface{} + f, id, err = ensureID(f, bson.NilObjectID, bsonOpts, registry) + if err != nil { + return nil, nil, err + } + doc = bsoncore.AppendDocumentElement(doc, "document", f) + doc, err = bsoncore.AppendDocumentEnd(doc, uidx) + return id, doc, err +} + +type clientUpdateDoc struct { + namespace int + filter interface{} + update interface{} + hint interface{} + arrayFilters []interface{} + collation *options.Collation + upsert *bool + multi bool + checkDollarKey bool + + sizeLimit int +} + +func (d *clientUpdateDoc) marshal(bsonOpts *options.BSONOptions, registry *bson.Registry) (bsoncore.Document, error) { + uidx, doc := bsoncore.AppendDocumentStart(nil) + + doc = bsoncore.AppendInt32Element(doc, "update", int32(d.namespace)) + + f, err := marshal(d.filter, bsonOpts, registry) + if err != nil { + return nil, err + } + doc = bsoncore.AppendDocumentElement(doc, "filter", f) + + u, err := marshalUpdateValue(d.update, bsonOpts, registry, d.checkDollarKey) + if err != nil { + return nil, err + } + if d.sizeLimit > 0 && len(u.Data) > d.sizeLimit { + return nil, driver.ErrDocumentTooLarge + } + doc = bsoncore.AppendValueElement(doc, "updateMods", u) + doc = bsoncore.AppendBooleanElement(doc, "multi", d.multi) + + if d.arrayFilters != nil { + reg := registry + arr, err := marshalValue(d.arrayFilters, bsonOpts, reg) + if err != nil { + return nil, err + } + doc = bsoncore.AppendArrayElement(doc, "arrayFilters", arr.Data) + } + + if d.collation != nil { + doc = bsoncore.AppendDocumentElement(doc, "collation", toDocument(d.collation)) + } + + if d.upsert != nil { + doc = bsoncore.AppendBooleanElement(doc, "upsert", *d.upsert) + } + + if d.hint != nil { + if isUnorderedMap(d.hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hintVal, err := marshalValue(d.hint, bsonOpts, registry) + if err != nil { + return nil, err + } + doc = bsoncore.AppendValueElement(doc, "hint", hintVal) + } + + 
return bsoncore.AppendDocumentEnd(doc, uidx) +} + +type clientDeleteDoc struct { + namespace int + filter interface{} + collation *options.Collation + hint interface{} + multi bool +} + +func (d *clientDeleteDoc) marshal(bsonOpts *options.BSONOptions, registry *bson.Registry) (bsoncore.Document, error) { + didx, doc := bsoncore.AppendDocumentStart(nil) + + doc = bsoncore.AppendInt32Element(doc, "delete", int32(d.namespace)) + + f, err := marshal(d.filter, bsonOpts, registry) + if err != nil { + return nil, err + } + doc = bsoncore.AppendDocumentElement(doc, "filter", f) + doc = bsoncore.AppendBooleanElement(doc, "multi", d.multi) + + if d.collation != nil { + doc = bsoncore.AppendDocumentElement(doc, "collation", toDocument(d.collation)) + } + if d.hint != nil { + if isUnorderedMap(d.hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hintVal, err := marshalValue(d.hint, bsonOpts, registry) + if err != nil { + return nil, err + } + doc = bsoncore.AppendValueElement(doc, "hint", hintVal) + } + return bsoncore.AppendDocumentEnd(doc, didx) +} diff --git a/mongo/client_bulk_write_models.go b/mongo/client_bulk_write_models.go new file mode 100644 index 0000000000..b1ed7f5eb9 --- /dev/null +++ b/mongo/client_bulk_write_models.go @@ -0,0 +1,359 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mongo + +import ( + "fmt" + + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +// ClientWriteModels is a struct that can be used in a client-level BulkWrite operation. +type ClientWriteModels struct { + models []clientWriteModel +} +type clientWriteModel struct { + namespace string + model interface{} +} + +// AppendInsertOne appends ClientInsertOneModels. 
+func (m *ClientWriteModels) AppendInsertOne(database, collection string, models ...*ClientInsertOneModel) *ClientWriteModels { + if m == nil { + m = &ClientWriteModels{} + } + for _, model := range models { + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) + } + return m +} + +// AppendUpdateOne appends ClientUpdateOneModels. +func (m *ClientWriteModels) AppendUpdateOne(database, collection string, models ...*ClientUpdateOneModel) *ClientWriteModels { + if m == nil { + m = &ClientWriteModels{} + } + for _, model := range models { + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) + } + return m +} + +// AppendUpdateMany appends ClientUpdateManyModels. +func (m *ClientWriteModels) AppendUpdateMany(database, collection string, models ...*ClientUpdateManyModel) *ClientWriteModels { + if m == nil { + m = &ClientWriteModels{} + } + for _, model := range models { + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) + } + return m +} + +// AppendReplaceOne appends ClientReplaceOneModels. +func (m *ClientWriteModels) AppendReplaceOne(database, collection string, models ...*ClientReplaceOneModel) *ClientWriteModels { + if m == nil { + m = &ClientWriteModels{} + } + for _, model := range models { + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) + } + return m +} + +// AppendDeleteOne appends ClientDeleteOneModels. 
+func (m *ClientWriteModels) AppendDeleteOne(database, collection string, models ...*ClientDeleteOneModel) *ClientWriteModels { + if m == nil { + m = &ClientWriteModels{} + } + for _, model := range models { + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) + } + return m +} + +// AppendDeleteMany appends ClientDeleteManyModels. +func (m *ClientWriteModels) AppendDeleteMany(database, collection string, models ...*ClientDeleteManyModel) *ClientWriteModels { + if m == nil { + m = &ClientWriteModels{} + } + for _, model := range models { + m.models = append(m.models, clientWriteModel{ + namespace: fmt.Sprintf("%s.%s", database, collection), + model: model, + }) + } + return m +} + +// ClientInsertOneModel is used to insert a single document in a client-level BulkWrite operation. +type ClientInsertOneModel struct { + Document interface{} +} + +// NewClientInsertOneModel creates a new ClientInsertOneModel. +func NewClientInsertOneModel() *ClientInsertOneModel { + return &ClientInsertOneModel{} +} + +// SetDocument specifies the document to be inserted. The document cannot be nil. If it does not have an _id field when +// transformed into BSON, one will be added automatically to the marshalled document. The original document will not be +// modified. +func (iom *ClientInsertOneModel) SetDocument(doc interface{}) *ClientInsertOneModel { + iom.Document = doc + return iom +} + +// ClientUpdateOneModel is used to update at most one document in a client-level BulkWrite operation. +type ClientUpdateOneModel struct { + Collation *options.Collation + Upsert *bool + Filter interface{} + Update interface{} + ArrayFilters []interface{} + Hint interface{} +} + +// NewClientUpdateOneModel creates a new ClientUpdateOneModel. +func NewClientUpdateOneModel() *ClientUpdateOneModel { + return &ClientUpdateOneModel{} +} + +// SetHint specifies the index to use for the operation. 
This should either be the index name as a string or the index +// specification as a document. The default value is nil, which means that no hint will be sent. +func (uom *ClientUpdateOneModel) SetHint(hint interface{}) *ClientUpdateOneModel { + uom.Hint = hint + return uom +} + +// SetFilter specifies a filter to use to select the document to update. The filter must be a document containing query +// operators. It cannot be nil. If the filter matches multiple documents, one will be selected from the matching +// documents. +func (uom *ClientUpdateOneModel) SetFilter(filter interface{}) *ClientUpdateOneModel { + uom.Filter = filter + return uom +} + +// SetUpdate specifies the modifications to be made to the selected document. The value must be a document containing +// update operators (https://www.mongodb.com/docs/manual/reference/operator/update/). It cannot be nil or empty. +func (uom *ClientUpdateOneModel) SetUpdate(update interface{}) *ClientUpdateOneModel { + uom.Update = update + return uom +} + +// SetArrayFilters specifies a set of filters to determine which elements should be modified when updating an array +// field. +func (uom *ClientUpdateOneModel) SetArrayFilters(filters []interface{}) *ClientUpdateOneModel { + uom.ArrayFilters = filters + return uom +} + +// SetCollation specifies a collation to use for string comparisons. The default is nil, meaning no collation will be +// used. +func (uom *ClientUpdateOneModel) SetCollation(collation *options.Collation) *ClientUpdateOneModel { + uom.Collation = collation + return uom +} + +// SetUpsert specifies whether or not a new document should be inserted if no document matching the filter is found. If +// an upsert is performed, the _id of the upserted document can be retrieved from the UpdateResults field of the +// ClientBulkWriteResult. 
+func (uom *ClientUpdateOneModel) SetUpsert(upsert bool) *ClientUpdateOneModel { + uom.Upsert = &upsert + return uom +} + +// ClientUpdateManyModel is used to update multiple documents in a client-level BulkWrite operation. +type ClientUpdateManyModel struct { + Collation *options.Collation + Upsert *bool + Filter interface{} + Update interface{} + ArrayFilters []interface{} + Hint interface{} +} + +// NewClientUpdateManyModel creates a new ClientUpdateManyModel. +func NewClientUpdateManyModel() *ClientUpdateManyModel { + return &ClientUpdateManyModel{} +} + +// SetHint specifies the index to use for the operation. This should either be the index name as a string or the index +// specification as a document. The default value is nil, which means that no hint will be sent. +func (umm *ClientUpdateManyModel) SetHint(hint interface{}) *ClientUpdateManyModel { + umm.Hint = hint + return umm +} + +// SetFilter specifies a filter to use to select documents to update. The filter must be a document containing query +// operators. It cannot be nil. +func (umm *ClientUpdateManyModel) SetFilter(filter interface{}) *ClientUpdateManyModel { + umm.Filter = filter + return umm +} + +// SetUpdate specifies the modifications to be made to the selected documents. The value must be a document containing +// update operators (https://www.mongodb.com/docs/manual/reference/operator/update/). It cannot be nil or empty. +func (umm *ClientUpdateManyModel) SetUpdate(update interface{}) *ClientUpdateManyModel { + umm.Update = update + return umm +} + +// SetArrayFilters specifies a set of filters to determine which elements should be modified when updating an array +// field. +func (umm *ClientUpdateManyModel) SetArrayFilters(filters []interface{}) *ClientUpdateManyModel { + umm.ArrayFilters = filters + return umm +} + +// SetCollation specifies a collation to use for string comparisons. The default is nil, meaning no collation will be +// used. 
+func (umm *ClientUpdateManyModel) SetCollation(collation *options.Collation) *ClientUpdateManyModel { + umm.Collation = collation + return umm +} + +// SetUpsert specifies whether or not a new document should be inserted if no document matching the filter is found. If +// an upsert is performed, the _id of the upserted document can be retrieved from the UpdateResults field of the +// ClientBulkWriteResult. +func (umm *ClientUpdateManyModel) SetUpsert(upsert bool) *ClientUpdateManyModel { + umm.Upsert = &upsert + return umm +} + +// ClientReplaceOneModel is used to replace at most one document in a client-level BulkWrite operation. +type ClientReplaceOneModel struct { + Collation *options.Collation + Upsert *bool + Filter interface{} + Replacement interface{} + Hint interface{} +} + +// NewClientReplaceOneModel creates a new ClientReplaceOneModel. +func NewClientReplaceOneModel() *ClientReplaceOneModel { + return &ClientReplaceOneModel{} +} + +// SetHint specifies the index to use for the operation. This should either be the index name as a string or the index +// specification as a document. The default value is nil, which means that no hint will be sent. +func (rom *ClientReplaceOneModel) SetHint(hint interface{}) *ClientReplaceOneModel { + rom.Hint = hint + return rom +} + +// SetFilter specifies a filter to use to select the document to replace. The filter must be a document containing query +// operators. It cannot be nil. If the filter matches multiple documents, one will be selected from the matching +// documents. +func (rom *ClientReplaceOneModel) SetFilter(filter interface{}) *ClientReplaceOneModel { + rom.Filter = filter + return rom +} + +// SetReplacement specifies a document that will be used to replace the selected document. It cannot be nil and cannot +// contain any update operators (https://www.mongodb.com/docs/manual/reference/operator/update/). 
+func (rom *ClientReplaceOneModel) SetReplacement(rep interface{}) *ClientReplaceOneModel { + rom.Replacement = rep + return rom +} + +// SetCollation specifies a collation to use for string comparisons. The default is nil, meaning no collation will be +// used. +func (rom *ClientReplaceOneModel) SetCollation(collation *options.Collation) *ClientReplaceOneModel { + rom.Collation = collation + return rom +} + +// SetUpsert specifies whether or not the replacement document should be inserted if no document matching the filter is +// found. If an upsert is performed, the _id of the upserted document can be retrieved from the UpdateResults field of the +// BulkWriteResult. +func (rom *ClientReplaceOneModel) SetUpsert(upsert bool) *ClientReplaceOneModel { + rom.Upsert = &upsert + return rom +} + +// ClientDeleteOneModel is used to delete at most one document in a client-level BulkWriteOperation. +type ClientDeleteOneModel struct { + Filter interface{} + Collation *options.Collation + Hint interface{} +} + +// NewClientDeleteOneModel creates a new ClientDeleteOneModel. +func NewClientDeleteOneModel() *ClientDeleteOneModel { + return &ClientDeleteOneModel{} +} + +// SetFilter specifies a filter to use to select the document to delete. The filter must be a document containing query +// operators. It cannot be nil. If the filter matches multiple documents, one will be selected from the matching +// documents. +func (dom *ClientDeleteOneModel) SetFilter(filter interface{}) *ClientDeleteOneModel { + dom.Filter = filter + return dom +} + +// SetCollation specifies a collation to use for string comparisons. The default is nil, meaning no collation will be +// used. +func (dom *ClientDeleteOneModel) SetCollation(collation *options.Collation) *ClientDeleteOneModel { + dom.Collation = collation + return dom +} + +// SetHint specifies the index to use for the operation. This should either be the index name as a string or the index +// specification as a document. 
The default value is nil, which means that no hint will be sent. +func (dom *ClientDeleteOneModel) SetHint(hint interface{}) *ClientDeleteOneModel { + dom.Hint = hint + return dom +} + +// ClientDeleteManyModel is used to delete multiple documents in a client-level BulkWrite operation. +type ClientDeleteManyModel struct { + Filter interface{} + Collation *options.Collation + Hint interface{} +} + +// NewClientDeleteManyModel creates a new ClientDeleteManyModel. +func NewClientDeleteManyModel() *ClientDeleteManyModel { + return &ClientDeleteManyModel{} +} + +// SetFilter specifies a filter to use to select documents to delete. The filter must be a document containing query +// operators. It cannot be nil. +func (dmm *ClientDeleteManyModel) SetFilter(filter interface{}) *ClientDeleteManyModel { + dmm.Filter = filter + return dmm +} + +// SetCollation specifies a collation to use for string comparisons. The default is nil, meaning no collation will be +// used. +func (dmm *ClientDeleteManyModel) SetCollation(collation *options.Collation) *ClientDeleteManyModel { + dmm.Collation = collation + return dmm +} + +// SetHint specifies the index to use for the operation. This should either be the index name as a string or the index +// specification as a document. The default value is nil, which means that no hint will be sent. +func (dmm *ClientDeleteManyModel) SetHint(hint interface{}) *ClientDeleteManyModel { + dmm.Hint = hint + return dmm +} diff --git a/mongo/client_bulk_write_test.go b/mongo/client_bulk_write_test.go new file mode 100644 index 0000000000..9d46ffccb3 --- /dev/null +++ b/mongo/client_bulk_write_test.go @@ -0,0 +1,66 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. 
You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mongo + +import ( + "testing" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/require" +) + +func TestBatches(t *testing.T) { + t.Run("test Addvancing", func(t *testing.T) { + batches := &modelBatches{ + models: make([]clientWriteModel, 2), + } + batches.AdvanceBatches(3) + size := batches.Size() + assert.Equal(t, 0, size, "expected: %d, got: %d", 1, size) + }) + t.Run("test appendBatches", func(t *testing.T) { + client, err := newClient() + require.NoError(t, err, "NewClient error: %v", err) + batches := &modelBatches{ + client: client, + models: []clientWriteModel{ + {"ns0", nil}, + {"ns1", &ClientInsertOneModel{ + Document: bson.D{{"foo", 42}}, + }}, + {"ns2", &ClientReplaceOneModel{ + Filter: bson.D{{"foo", "bar"}}, + Replacement: bson.D{{"foo", "baz"}}, + }}, + {"ns1", &ClientDeleteOneModel{ + Filter: bson.D{{"qux", "quux"}}, + }}, + }, + offset: 1, + result: &ClientBulkWriteResult{}, + } + var n int + n, _, err = batches.AppendBatchSequence(nil, 4, 16_000, 16_000) + require.NoError(t, err, "AppendBatchSequence error: %v", err) + assert.Equal(t, 3, n, "expected %d appendings, got: %d", 3, n) + + _ = batches.cursorHandlers[0](&cursorInfo{Ok: true, Idx: 0}, nil) + _ = batches.cursorHandlers[1](&cursorInfo{Ok: true, Idx: 1}, nil) + _ = batches.cursorHandlers[2](&cursorInfo{Ok: true, Idx: 2}, nil) + + ins, ok := batches.result.InsertResults[1] + assert.True(t, ok, "expected an insert results") + assert.NotNil(t, ins.InsertedID, "expected an ID") + + _, ok = batches.result.UpdateResults[2] + assert.True(t, ok, "expected an insert results") + + _, ok = batches.result.DeleteResults[3] + assert.True(t, ok, "expected an insert results") + }) +} diff --git a/mongo/errors.go b/mongo/errors.go index b9f8ec8d8e..b8f6090c4f 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -612,6 +612,56 
@@ func (bwe BulkWriteException) HasErrorCodeWithMessage(code int, message string) // serverError implements the ServerError interface. func (bwe BulkWriteException) serverError() {} +// ClientBulkWriteException is the error type returned by ClientBulkWrite operations. +type ClientBulkWriteException struct { + // A top-level error that occurred when attempting to communicate with the server + // or execute the bulk write. This value may not be populated if the exception was + // thrown due to errors occurring on individual writes. + TopLevelError *WriteError + + // The write concern errors that occurred. + WriteConcernErrors []WriteConcernError + + // The write errors that occurred during individual operation execution. + // This map will contain at most one entry if the bulk write was ordered. + WriteErrors map[int]WriteError + + // The results of any successful operations that were performed before the error + // was encountered. + PartialResult *ClientBulkWriteResult +} + +// Error implements the error interface. 
+func (bwe ClientBulkWriteException) Error() string { + causes := make([]string, 0, 4) + if bwe.TopLevelError != nil { + causes = append(causes, "top level error: "+bwe.TopLevelError.Error()) + } + if len(bwe.WriteConcernErrors) > 0 { + errs := make([]error, len(bwe.WriteConcernErrors)) + for i := 0; i < len(bwe.WriteConcernErrors); i++ { + errs[i] = bwe.WriteConcernErrors[i] + } + causes = append(causes, "write concern errors: "+joinBatchErrors(errs)) + } + if len(bwe.WriteErrors) > 0 { + errs := make([]error, 0, len(bwe.WriteErrors)) + for _, v := range bwe.WriteErrors { + errs = append(errs, v) + } + causes = append(causes, "write errors: "+joinBatchErrors(errs)) + } + if bwe.PartialResult != nil { + causes = append(causes, fmt.Sprintf("result: %v", *bwe.PartialResult)) + } + + message := "bulk write exception: " + if len(causes) == 0 { + return message + "no causes" + } + return "bulk write exception: " + strings.Join(causes, ", ") +} + // returnResult is used to determine if a function calling processWriteError should return // the result or return nil. Since the processWriteError function is used by many different // methods, both *One and *Many, we need a way to differentiate if the method should return diff --git a/mongo/options/clientbulkwriteoptions.go b/mongo/options/clientbulkwriteoptions.go new file mode 100644 index 0000000000..3d9a4cb92e --- /dev/null +++ b/mongo/options/clientbulkwriteoptions.go @@ -0,0 +1,136 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package options + +import ( + "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" +) + +// ClientBulkWriteOptions represents options that can be used to configure a client-level BulkWrite operation. 
+type ClientBulkWriteOptions struct { + // If true, writes executed as part of the operation will opt out of document-level validation on the server. The + // default value is false. See https://www.mongodb.com/docs/manual/core/schema-validation/ for more information + // about document validation. + BypassDocumentValidation *bool + + // A string or document that will be included in server logs, profiling logs, and currentOp queries to help trace + // the operation. The default value is nil, which means that no comment will be included in the logs. + Comment interface{} + + // If true, no writes will be executed after one fails. The default value is true. + Ordered *bool + + // Specifies parameters for all update and delete commands in the BulkWrite. This must be a document mapping + // parameter names to values. Values must be constant or closed expressions that do not reference document fields. + // Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). + Let interface{} + + // The write concern to use for this bulk write. + WriteConcern *writeconcern.WriteConcern + + // Whether detailed results for each successful operation should be included in the returned BulkWriteResult. + VerboseResults *bool +} + +// ClientBulkWriteOptionsBuilder contains options to configure client-level bulk +// write operations. +// Each option can be set through setter functions. See documentation for each +// setter function for an explanation of the option. +type ClientBulkWriteOptionsBuilder struct { + Opts []func(*ClientBulkWriteOptions) error +} + +// ClientBulkWrite creates a new *ClientBulkWriteOptions instance. +func ClientBulkWrite() *ClientBulkWriteOptionsBuilder { + opts := &ClientBulkWriteOptionsBuilder{} + opts = opts.SetOrdered(DefaultOrdered) + + return opts +} + +// List returns a list of ClientBulkWriteOptions setter functions. 
+func (b *ClientBulkWriteOptionsBuilder) List() []func(*ClientBulkWriteOptions) error { + return b.Opts +} + +// SetComment sets the value for the Comment field. Specifies a string or document that will be included in +// server logs, profiling logs, and currentOp queries to help tracethe operation. The default value is nil, +// which means that no comment will be included in the logs. +func (b *ClientBulkWriteOptionsBuilder) SetComment(comment interface{}) *ClientBulkWriteOptionsBuilder { + b.Opts = append(b.Opts, func(opts *ClientBulkWriteOptions) error { + opts.Comment = comment + + return nil + }) + + return b +} + +// SetOrdered sets the value for the Ordered field. If true, no writes will be executed after one fails. +// The default value is true. +func (b *ClientBulkWriteOptionsBuilder) SetOrdered(ordered bool) *ClientBulkWriteOptionsBuilder { + b.Opts = append(b.Opts, func(opts *ClientBulkWriteOptions) error { + opts.Ordered = &ordered + + return nil + }) + + return b +} + +// SetBypassDocumentValidation sets the value for the BypassDocumentValidation field. If true, writes +// executed as part of the operation will opt out of document-level validation on the server. The default +// value is false. See https://www.mongodb.com/docs/manual/core/schema-validation/ for more information +// about document validation. +func (b *ClientBulkWriteOptionsBuilder) SetBypassDocumentValidation(bypass bool) *ClientBulkWriteOptionsBuilder { + b.Opts = append(b.Opts, func(opts *ClientBulkWriteOptions) error { + opts.BypassDocumentValidation = &bypass + + return nil + }) + + return b +} + +// SetLet sets the value for the Let field. Let specifies parameters for all update and delete commands in the BulkWrite. +// This must be a document mapping parameter names to values. Values must be constant or closed expressions that do not +// reference document fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). 
+func (b *ClientBulkWriteOptionsBuilder) SetLet(let interface{}) *ClientBulkWriteOptionsBuilder { + b.Opts = append(b.Opts, func(opts *ClientBulkWriteOptions) error { + opts.Let = &let + + return nil + }) + + return b +} + +// SetWriteConcern sets the value for the WriteConcern field. Specifies the write concern for +// operations in the transaction. The default value is nil, which means that the default +// write concern of the session used to start the transaction will be used. +func (b *ClientBulkWriteOptionsBuilder) SetWriteConcern(wc *writeconcern.WriteConcern) *ClientBulkWriteOptionsBuilder { + b.Opts = append(b.Opts, func(opts *ClientBulkWriteOptions) error { + opts.WriteConcern = wc + + return nil + }) + + return b +} + +// SetVerboseResults sets the value for the VerboseResults field. Specifies whether detailed +// results for each successful operation should be included in the returned BulkWriteResult. +// The defaults value is false. +func (b *ClientBulkWriteOptionsBuilder) SetVerboseResults(verboseResults bool) *ClientBulkWriteOptionsBuilder { + b.Opts = append(b.Opts, func(opts *ClientBulkWriteOptions) error { + opts.VerboseResults = &verboseResults + + return nil + }) + + return b +} diff --git a/mongo/results.go b/mongo/results.go index 887c0d646e..92474a8466 100644 --- a/mongo/results.go +++ b/mongo/results.go @@ -13,6 +13,51 @@ import ( "go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation" ) +// ClientBulkWriteResult is the result type returned by a client-level BulkWrite operation. +type ClientBulkWriteResult struct { + // The number of documents inserted. + InsertedCount int64 + + // The number of documents matched by filters in update and replace operations. + MatchedCount int64 + + // The number of documents modified by update and replace operations. + ModifiedCount int64 + + // The number of documents deleted. + DeletedCount int64 + + // The number of documents upserted by update and replace operations. 
+ UpsertedCount int64 + + // A map of operation index to the _id of each inserted document. + InsertResults map[int]ClientInsertResult + + // A map of operation index to the _id of each updated document. + UpdateResults map[int]ClientUpdateResult + + // A map of operation index to the _id of each deleted document. + DeleteResults map[int]ClientDeleteResult +} + +// ClientInsertResult is the result type returned by a client-level bulk write of InsertOne operation. +type ClientInsertResult struct { + // The _id of the inserted document. A value generated by the driver will be of type primitive.ObjectID. + InsertedID interface{} +} + +// ClientUpdateResult is the result type returned from a client-level bulk write of UpdateOne, UpdateMany, and ReplaceOne operation. +type ClientUpdateResult struct { + MatchedCount int64 // The number of documents matched by the filter. + ModifiedCount int64 // The number of documents modified by the operation. + UpsertedID interface{} // The _id field of the upserted document, or nil if no upsert was done. +} + +// ClientDeleteResult is the result type returned by a client-level bulk write DeleteOne and DeleteMany operation. +type ClientDeleteResult struct { + DeletedCount int64 // The number of documents deleted. +} + // BulkWriteResult is the result type returned by a BulkWrite operation. type BulkWriteResult struct { // The number of documents inserted. diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index 47abe3e5f2..f444739661 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -75,25 +75,29 @@ type CursorResponse struct { postBatchResumeToken bsoncore.Document } -// NewCursorResponse constructs a cursor response from the given response and -// server. If the provided database response does not contain a cursor, it -// returns ErrNoCursor. -// -// NewCursorResponse can be used within the ProcessResponse method for an operation. 
-func NewCursorResponse(info ResponseInfo) (CursorResponse, error) { - response := info.ServerResponse +// ExtractCursorDocument retrieves cursor document from a database response. If the +// provided response does not contain a cursor, it returns ErrNoCursor. +func ExtractCursorDocument(response bsoncore.Document) (bsoncore.Document, error) { cur, err := response.LookupErr("cursor") if errors.Is(err, bsoncore.ErrElementNotFound) { - return CursorResponse{}, ErrNoCursor + return nil, ErrNoCursor } if err != nil { - return CursorResponse{}, fmt.Errorf("error getting cursor from database response: %w", err) + return nil, fmt.Errorf("error getting cursor from database response: %w", err) } curDoc, ok := cur.DocumentOK() if !ok { - return CursorResponse{}, fmt.Errorf("cursor should be an embedded document but is BSON type %s", cur.Type) + return nil, fmt.Errorf("cursor should be an embedded document but is BSON type %s", cur.Type) } - elems, err := curDoc.Elements() + return curDoc, nil +} + +// NewCursorResponse constructs a cursor response from the given cursor document +// extracted from a database response. +// +// NewCursorResponse can be used within the ProcessResponse method for an operation. 
+func NewCursorResponse(response bsoncore.Document, info ResponseInfo) (CursorResponse, error) { + elems, err := response.Elements() if err != nil { return CursorResponse{}, fmt.Errorf("error getting elements from cursor: %w", err) } @@ -120,15 +124,17 @@ func NewCursorResponse(info ResponseInfo) (CursorResponse, error) { curresp.Database = database curresp.Collection = collection case "id": - curresp.ID, ok = elem.Value().Int64OK() + id, ok := elem.Value().Int64OK() if !ok { return CursorResponse{}, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type) } + curresp.ID = id case "postBatchResumeToken": - curresp.postBatchResumeToken, ok = elem.Value().DocumentOK() + token, ok := elem.Value().DocumentOK() if !ok { return CursorResponse{}, fmt.Errorf("post batch resume token should be a document but it is a BSON %s", elem.Value().Type) } + curresp.postBatchResumeToken = token } } @@ -399,8 +405,7 @@ func (bc *BatchCursor) getMore(ctx context.Context) { }, Database: bc.database, Deployment: bc.getOperationDeployment(), - ProcessResponseFn: func(info ResponseInfo) error { - response := info.ServerResponse + ProcessResponseFn: func(_ context.Context, response bsoncore.Document, _ ResponseInfo) error { id, ok := response.Lookup("cursor", "id").Int64OK() if !ok { return fmt.Errorf("cursor.id should be an int64 but is a BSON %s", response.Lookup("cursor", "id").Type) diff --git a/x/mongo/driver/batches.go b/x/mongo/driver/batches.go index 73812f0587..4f096616a7 100644 --- a/x/mongo/driver/batches.go +++ b/x/mongo/driver/batches.go @@ -7,70 +7,114 @@ package driver import ( - "errors" + "io" + "strconv" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" ) -// ErrDocumentTooLarge occurs when a document that is larger than the maximum size accepted by a -// server is passed to an insert command. 
-var ErrDocumentTooLarge = errors.New("an inserted document is too large") - // Batches contains the necessary information to batch split an operation. This is only used for write // operations. type Batches struct { Identifier string Documents []bsoncore.Document - Current []bsoncore.Document Ordered *bool -} - -// Valid returns true if Batches contains both an identifier and the length of Documents is greater -// than zero. -func (b *Batches) Valid() bool { return b != nil && b.Identifier != "" && len(b.Documents) > 0 } - -// ClearBatch clears the Current batch. This must be called before AdvanceBatch will advance to the -// next batch. -func (b *Batches) ClearBatch() { b.Current = b.Current[:0] } -// AdvanceBatch splits the next batch using maxCount and targetBatchSize. This method will do nothing if -// the current batch has not been cleared. We do this so that when this is called during execute we -// can call it without first needing to check if we already have a batch, which makes the code -// simpler and makes retrying easier. -// The maxDocSize parameter is used to check that any one document is not too large. If the first document is bigger -// than targetBatchSize but smaller than maxDocSize, a batch of size 1 containing that document will be created. -func (b *Batches) AdvanceBatch(maxCount, targetBatchSize, maxDocSize int) error { - if len(b.Current) > 0 { - return nil - } + offset int +} - if maxCount <= 0 { - maxCount = 1 +// AppendBatchSequence appends dst with document sequence of batches as long as the limits of max count, max +// document size, or total size allows. It returns the number of batches appended, the new appended slice, and +// any error raised. It returns the origenal input slice if nothing can be appends within the limits. 
+func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, _ int) (int, []byte, error) {
+	if b.Size() == 0 {
+		return 0, dst, io.EOF
 	}
-
-	splitAfter := 0
-	size := 0
-	for i, doc := range b.Documents {
-		if i == maxCount {
+	l := len(dst)
+	var idx int32
+	dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence)
+	idx, dst = bsoncore.ReserveLength(dst)
+	dst = append(dst, b.Identifier...)
+	dst = append(dst, 0x00)
+	var size int
+	var n int
+	for i := b.offset; i < len(b.Documents); i++ {
+		if n == maxCount {
 			break
 		}
+		doc := b.Documents[i]
 		if len(doc) > maxDocSize {
-			return ErrDocumentTooLarge
+			break
 		}
-		if size+len(doc) > targetBatchSize {
+		size += len(doc)
+		if size > maxDocSize {
 			break
 		}
+		dst = append(dst, doc...)
+		n++
+	}
+	if n == 0 {
+		return 0, dst[:l], nil
+	}
+	dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
+	return n, dst, nil
+}
+// AppendBatchArray appends dst with an array of batches as long as the limits of max count, max document size, or
+// total size allow. It returns the number of batches appended, the new appended slice, and any error raised. It
+// returns the original input slice if nothing can be appended within the limits.
+func (b *Batches) AppendBatchArray(dst []byte, maxCount, maxDocSize, _ int) (int, []byte, error) {
+	if b.Size() == 0 {
+		return 0, dst, io.EOF
+	}
+	l := len(dst)
+	aidx, dst := bsoncore.AppendArrayElementStart(dst, b.Identifier)
+	var size int
+	var n int
+	for i := b.offset; i < len(b.Documents); i++ {
+		if n == maxCount {
+			break
+		}
+		doc := b.Documents[i]
+		if len(doc) > maxDocSize {
+			break
+		}
 		size += len(doc)
-		splitAfter++
+		if size > maxDocSize {
+			break
+		}
+		dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(n), doc)
+		n++
+	}
+	if n == 0 {
+		return 0, dst[:l], nil
 	}
+	var err error
+	dst, err = bsoncore.AppendArrayEnd(dst, aidx)
+	if err != nil {
+		return 0, nil, err
+	}
+	return n, dst, nil
+}
+
+// IsOrdered indicates if the batches are ordered.
+func (b *Batches) IsOrdered() *bool {
+	return b.Ordered
+}
-	// if there are no documents, take the first one.
-	// this can happen if there is a document that is smaller than maxDocSize but greater than targetBatchSize.
-	if splitAfter == 0 {
-		splitAfter = 1
+// AdvanceBatches advances the batch offset by n, clamped to the number of documents.
+func (b *Batches) AdvanceBatches(n int) {
+	b.offset += n
+	if b.offset > len(b.Documents) {
+		b.offset = len(b.Documents)
 	}
+}
-	b.Current, b.Documents = b.Documents[:splitAfter], b.Documents[splitAfter:]
-	return nil
+// Size returns the number of documents remaining.
+func (b *Batches) Size() int {
+	if b.offset > len(b.Documents) {
+		return 0
+	}
+	return len(b.Documents) - b.offset
 }
diff --git a/x/mongo/driver/batches_test.go b/x/mongo/driver/batches_test.go
index 3bd17affcc..95cdb674de 100644
--- a/x/mongo/driver/batches_test.go
+++ b/x/mongo/driver/batches_test.go
@@ -9,129 +9,93 @@ package driver
 import (
 	"testing"

-	"github.com/google/go-cmp/cmp"
 	"go.mongodb.org/mongo-driver/v2/internal/assert"
 	"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
+	"go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage"
 )

-func TestBatches(t *testing.T) {
-	t.Run("Valid", func(t *testing.T) {
-		testCases := []struct {
-			name    string
-			batches *Batches
-			want    bool
-		}{
-			{"nil", nil, false},
-			{"missing identifier", &Batches{}, false},
-			{"no documents", &Batches{Identifier: "documents"}, false},
-			{"valid", &Batches{Identifier: "documents", Documents: make([]bsoncore.Document, 5)}, true},
-		}
+func newTestBatches(t *testing.T) *Batches {
+	t.Helper()
+	return &Batches{
+		Identifier: "foobar",
+		Documents: []bsoncore.Document{
+			[]byte("Lorem ipsum dolor sit amet"),
+			[]byte("consectetur adipiscing elit"),
+		},
+	}
+}

-		for _, tc := range testCases {
-			t.Run(tc.name, func(t *testing.T) {
-				want := tc.want
-				got := tc.batches.Valid()
-				if got != want {
-					t.Errorf("Did not get expected result from Valid.
got %t; want %t", got, want) - } - }) - } - }) - t.Run("ClearBatch", func(t *testing.T) { - batches := &Batches{Identifier: "documents", Current: make([]bsoncore.Document, 2, 10)} - if len(batches.Current) != 2 { - t.Fatalf("Length of current batch should be 2, but is %d", len(batches.Current)) - } - batches.ClearBatch() - if len(batches.Current) != 0 { - t.Fatalf("Length of current batch should be 0, but is %d", len(batches.Current)) - } +func TestAdvancing(t *testing.T) { + batches := newTestBatches(t) + batches.AdvanceBatches(3) + size := batches.Size() + assert.Equal(t, 0, size, "expected Size(): %d, got: %d", 1, size) +} + +func TestAppendBatchSequence(t *testing.T) { + t.Run("Append 0", func(t *testing.T) { + batches := newTestBatches(t) + + got := []byte{42} + var n int + var err error + n, got, err = batches.AppendBatchSequence(got, 2, len(batches.Documents[0])-1, 0) + assert.NoError(t, err) + assert.Equal(t, 0, n) + + assert.Equal(t, []byte{42}, got) }) - t.Run("AdvanceBatch", func(t *testing.T) { - documents := make([]bsoncore.Document, 0) - for i := 0; i < 5; i++ { - doc := make(bsoncore.Document, 100) - documents = append(documents, doc) - } + t.Run("Append 1", func(t *testing.T) { + batches := newTestBatches(t) - testCases := []struct { - name string - batches *Batches - maxCount int - targetBatchSize int - maxDocSize int - err error - want *Batches - }{ - { - "current batch non-zero", - &Batches{Current: make([]bsoncore.Document, 2, 10)}, - 0, 0, 0, nil, - &Batches{Current: make([]bsoncore.Document, 2, 10)}, - }, - { - // all of the documents in the batch fit in targetBatchSize so the batch is created successfully - "documents fit in targetBatchSize", - &Batches{Documents: documents}, - 10, 600, 1000, nil, - &Batches{Documents: documents[:0], Current: documents[0:]}, - }, - { - // the first doc is bigger than targetBatchSize but smaller than maxDocSize so it is taken alone - "first document larger than targetBatchSize, smaller than maxDocSize", - 
&Batches{Documents: documents}, - 10, 5, 100, nil, - &Batches{Documents: documents[1:], Current: documents[:1]}, - }, - } + got := []byte{42} + var n int + var err error + n, got, err = batches.AppendBatchSequence(got, 2, len(batches.Documents[0]), 0) + assert.NoError(t, err) + assert.Equal(t, 1, n) - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := tc.batches.AdvanceBatch(tc.maxCount, tc.targetBatchSize, tc.maxDocSize) - if !cmp.Equal(err, tc.err, cmp.Comparer(compareErrors)) { - t.Errorf("Errors do not match. got %v; want %v", err, tc.err) - } - if !cmp.Equal(tc.batches, tc.want) { - t.Errorf("Batches is not in correct state after AdvanceBatch. got %v; want %v", tc.batches, tc.want) - } - }) - } + var idx int32 + dst := []byte{42} + dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence) + idx, dst = bsoncore.ReserveLength(dst) + dst = append(dst, "foobar"...) + dst = append(dst, 0x00) + dst = append(dst, "Lorem ipsum dolor sit amet"...) + dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) + assert.Equal(t, dst, got) + }) +} - t.Run("middle document larger than targetBatchSize, smaller than maxDocSize", func(t *testing.T) { - // a batch is made but one document is too big, so everything before it is taken. 
- // on the second call to AdvanceBatch, only the large document is taken +func TestAppendBatchArray(t *testing.T) { + t.Run("Append 0", func(t *testing.T) { + batches := newTestBatches(t) - middleLargeDoc := make([]bsoncore.Document, 0) - for i := 0; i < 5; i++ { - doc := make(bsoncore.Document, 100) - middleLargeDoc = append(middleLargeDoc, doc) - } - largeDoc := make(bsoncore.Document, 900) - middleLargeDoc[2] = largeDoc - batches := &Batches{Documents: middleLargeDoc} - maxCount := 10 - targetSize := 600 - maxDocSize := 1000 + got := []byte{42} + var n int + var err error + n, got, err = batches.AppendBatchArray(got, 2, len(batches.Documents[0])-1, 0) + assert.NoError(t, err) + assert.Equal(t, 0, n) - // first batch should take first 2 docs (size 100 each) - err := batches.AdvanceBatch(maxCount, targetSize, maxDocSize) - assert.Nil(t, err, "AdvanceBatch error: %v", err) - want := &Batches{Current: middleLargeDoc[:2], Documents: middleLargeDoc[2:]} - assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches) + assert.Equal(t, []byte{42}, got) + }) + t.Run("Append 1", func(t *testing.T) { + batches := newTestBatches(t) - // second batch should take single large doc (size 900) - batches.ClearBatch() - err = batches.AdvanceBatch(maxCount, targetSize, maxDocSize) - assert.Nil(t, err, "AdvanceBatch error: %v", err) - want = &Batches{Current: middleLargeDoc[2:3], Documents: middleLargeDoc[3:]} - assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches) + got := []byte{42} + var n int + var err error + n, got, err = batches.AppendBatchArray(got, 2, len(batches.Documents[0]), 0) + assert.NoError(t, err) + assert.Equal(t, 1, n) - // last batch should take last 2 docs (size 100 each) - batches.ClearBatch() - err = batches.AdvanceBatch(maxCount, targetSize, maxDocSize) - assert.Nil(t, err, "AdvanceBatch error: %v", err) - want = &Batches{Current: middleLargeDoc[3:], Documents: middleLargeDoc[:0]} - assert.Equal(t, want, batches, 
"expected batches %v, got %v", want, batches) - }) + var idx int32 + dst := []byte{42} + idx, dst = bsoncore.AppendArrayElementStart(dst, "foobar") + dst = bsoncore.AppendDocumentElement(dst, "0", []byte("Lorem ipsum dolor sit amet")) + dst, err = bsoncore.AppendArrayEnd(dst, idx) + assert.NoError(t, err) + assert.Equal(t, dst, got) }) } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 968b2f258c..fdf8732f21 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -51,6 +51,9 @@ var ( ErrEmptyReadConcern = errors.New("a read concern must have at least one field set") // ErrEmptyWriteConcern indicates that a write concern has no fields set. ErrEmptyWriteConcern = errors.New("a write concern must have at least one field set") + // ErrDocumentTooLarge occurs when a document that is larger than the maximum size accepted by a + // server is passed to an insert command. + ErrDocumentTooLarge = errors.New("an inserted document is too large") // errDatabaseNameEmpty occurs when a database name is not provided. errDatabaseNameEmpty = errors.New("database name cannot be empty") // errNegativeW indicates that a negative integer `w` field was specified. @@ -61,7 +64,7 @@ var ( const ( // maximum BSON object size when in-use encryption is enabled - cryptMaxBsonObjectSize uint32 = 2097152 + cryptMaxBsonObjectSize int = 2097152 // minimum wire version necessary to use automatic encryption cryptMinWireVersion int32 = 8 // minimum wire version necessary to use read snapshots @@ -100,16 +103,17 @@ type opReply struct { // startedInformation keeps track of all of the information necessary for monitoring started events. 
type startedInformation struct { - cmd bsoncore.Document - requestID int32 - cmdName string - documentSequenceIncluded bool - connID string - driverConnectionID int64 - serverConnID *int64 - redacted bool - serviceID *bson.ObjectID - serverAddress address.Address + cmd bsoncore.Document + requestID int32 + cmdName string + documentSequence []byte + processedBatches int + connID string + driverConnectionID int64 + serverConnID *int64 + redacted bool + serviceID *bson.ObjectID + serverAddress address.Address } // finishedInformation keeps track of all of the information necessary for monitoring success and failure events. @@ -143,28 +147,27 @@ func (info finishedInformation) success() bool { // ResponseInfo contains the context required to parse a server response. type ResponseInfo struct { - ServerResponse bsoncore.Document Server Server Connection *mnet.Connection ConnectionDescription description.Server CurrentIndex int + Error error } -func redactStartedInformationCmd(op Operation, info startedInformation) bson.Raw { +func redactStartedInformationCmd(info startedInformation) bson.Raw { var cmdCopy bson.Raw // Make a copy of the command. Redact if the command is security // sensitive and cannot be monitored. If there was a type 1 payload for // the current batch, convert it to a BSON array if !info.redacted { - cmdCopy = make([]byte, len(info.cmd)) - copy(cmdCopy, info.cmd) + cmdCopy = make([]byte, 0, len(info.cmd)) + cmdCopy = append(cmdCopy, info.cmd...) - if info.documentSequenceIncluded { + if len(info.documentSequence) > 0 { // remove 0 byte at end cmdCopy = cmdCopy[:len(info.cmd)-1] - cmdCopy = op.addBatchArray(cmdCopy) - + cmdCopy = append(cmdCopy, info.documentSequence...) // add back 0 byte and update length cmdCopy, _ = bsoncore.AppendDocumentEnd(cmdCopy, 0) } @@ -210,7 +213,7 @@ type Operation struct { // ProcessResponseFn is called after a response to the command is returned. 
The server is // provided for types like Cursor that are required to run subsequent commands using the same // server. - ProcessResponseFn func(ResponseInfo) error + ProcessResponseFn func(context.Context, bsoncore.Document, ResponseInfo) error // Selector is the server selector that's used during both initial server selection and // subsequent selection for retries. Depending on the Deployment implementation, the @@ -268,7 +271,13 @@ type Operation struct { // has more documents than can fit in a single command. This should only be specified for // commands that are batch compatible. For more information, please refer to the definition of // Batches. - Batches *Batches + Batches interface { + AppendBatchSequence(dst []byte, maxCount int, maxDocSize int, totalSize int) (int, []byte, error) + AppendBatchArray(dst []byte, maxCount int, maxDocSize int, totalSize int) (int, []byte, error) + IsOrdered() *bool + AdvanceBatches(n int) + Size() int + } // Legacy sets the legacy type for this operation. There are only 3 types that require legacy // support: find, getMore, and killCursors. For more information about LegacyOperationKind, @@ -532,12 +541,12 @@ func (op Operation) Execute(ctx context.Context) error { retries = -1 } } - } - // If context is a Timeout context, automatically set retries to -1 (infinite) if retrying is - // enabled. - retryEnabled := op.RetryMode != nil && op.RetryMode.Enabled() - if csot.IsTimeoutContext(ctx) && retryEnabled { - retries = -1 + + // If context is a Timeout context, automatically set retries to -1 (infinite) if retrying is + // enabled. 
+ if csot.IsTimeoutContext(ctx) && op.RetryMode.Enabled() { + retries = -1 + } } var srvr Server @@ -546,7 +555,6 @@ func (op Operation) Execute(ctx context.Context) error { var operationErr WriteCommandError var prevErr error var prevIndefiniteErr error - batching := op.Batches.Valid() retrySupported := false first := true currIndex := 0 @@ -585,7 +593,7 @@ func (op Operation) Execute(ctx context.Context) error { if conn != nil { // If we are dealing with a sharded cluster, then mark the failed server // as "deprioritized". - if desc := conn.Description; desc != nil && op.Deployment.Kind() == description.TopologyKindSharded { + if op.Deployment.Kind() == description.TopologyKindSharded { deprioritizedServers = []description.Server{conn.Description()} } @@ -672,14 +680,10 @@ func (op Operation) Execute(ctx context.Context) error { // Calling IncrementTxnNumber() for server descriptions or topologies that do not // support retries (e.g. standalone topologies) will cause server errors. Only do this // check for the first attempt to keep retried writes in the same transaction. - if retrySupported && op.RetryMode != nil && op.Type == Write && op.Client != nil { - op.Client.RetryWrite = false - if op.RetryMode.Enabled() { - op.Client.RetryWrite = true - if !op.Client.Committing && !op.Client.Aborting { - op.Client.IncrementTxnNumber() - } - } + retryEnabled := op.RetryMode != nil && op.RetryMode.Enabled() + needToIncrease := op.Client != nil && !op.Client.Committing && !op.Client.Aborting + if retrySupported && op.Type == Write && retryEnabled && needToIncrease { + op.Client.IncrementTxnNumber() } first = false @@ -702,30 +706,14 @@ func (op Operation) Execute(ctx context.Context) error { Kind: op.Deployment.Kind(), } - if batching { - targetBatchSize := desc.MaxDocumentSize - maxDocSize := desc.MaxDocumentSize - if op.shouldEncrypt() { - // For in-use encryption, we want the batch to be split at 2 MiB instead of 16MiB. 
- // If there's only one document in the batch, it can be up to 16MiB, so we set target batch size to - // 2MiB but max document size to 16MiB. This will allow the AdvanceBatch call to create a batch - // with a single large document. - targetBatchSize = cryptMaxBsonObjectSize - } - - err = op.Batches.AdvanceBatch(int(desc.MaxBatchCount), int(targetBatchSize), int(maxDocSize)) - if err != nil { - // TODO(GODRIVER-982): Should we also be returning operationErr? - return err - } - } - + var moreToCome bool var startedInfo startedInformation - *wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) + *wm, moreToCome, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) if err != nil { return err } + retryEnabled := op.RetryMode != nil && op.RetryMode.Enabled() // set extra data and send event if possible startedInfo.connID = conn.ID() @@ -747,9 +735,6 @@ func (op Operation) Execute(ctx context.Context) error { op.publishStartedEvent(ctx, startedInfo) - // get the moreToCome flag information before we compress - moreToCome := wiremessage.IsMsgMoreToCome(*wm) - // compress wiremessage if allowed if compressor := conn.Compressor; compressor != nil && op.canCompress(startedInfo.cmdName) { b := memoryPool.Get().(*[]byte) @@ -812,7 +797,6 @@ func (op Operation) Execute(ctx context.Context) error { // TODO(GODRIVER-2579): When refactoring the "Execute" method, consider creating a separate method for the // error handling logic below. This will remove the necessity of the "checkError" goto label. 
checkError: - var perr error switch tt := err.(type) { case WriteCommandError: if e := err.(WriteCommandError); retrySupported && op.Type == Write && e.UnsupportedStorageEngine() { @@ -833,7 +817,7 @@ func (op Operation) Execute(ctx context.Context) error { // If retries are supported for the current operation on the first server description, // the error is considered retryable, and there are retries remaining (negative retries // means retry indefinitely), then retry the operation. - if retrySupported && retryableErr && retries != 0 { + if retrySupported && retryEnabled && retryableErr && retries != 0 { if op.Client != nil && op.Client.Committing { // Apply majority write concern for retries op.Client.UpdateCommitTransactionWriteConcern() @@ -856,25 +840,26 @@ func (op Operation) Execute(ctx context.Context) error { // If the operation isn't being retried, process the response if op.ProcessResponseFn != nil { info := ResponseInfo{ - ServerResponse: res, Server: srvr, Connection: conn, ConnectionDescription: desc.Server, CurrentIndex: currIndex, + Error: tt, } - _ = op.ProcessResponseFn(info) - } - - if batching && len(tt.WriteErrors) > 0 && currIndex > 0 { - for i := range tt.WriteErrors { - tt.WriteErrors[i].Index += int64(currIndex) - } + _ = op.ProcessResponseFn(ctx, res, info) } // If batching is enabled and either ordered is the default (which is true) or // explicitly set to true and we have write errors, return the errors. - if batching && (op.Batches.Ordered == nil || *op.Batches.Ordered) && len(tt.WriteErrors) > 0 { - return tt + if op.Batches != nil && len(tt.WriteErrors) > 0 { + if currIndex > 0 { + for i := range tt.WriteErrors { + tt.WriteErrors[i].Index += int64(currIndex) + } + } + if isOrdered := op.Batches.IsOrdered(); isOrdered == nil || *isOrdered { + return tt + } } if op.Client != nil && op.Client.Committing && tt.WriteConcernError != nil { // When running commitTransaction we return WriteConcernErrors as an Error. 
@@ -952,7 +937,7 @@ func (op Operation) Execute(ctx context.Context) error { // If retries are supported for the current operation on the first server description, // the error is considered retryable, and there are retries remaining (negative retries // means retry indefinitely), then retry the operation. - if retrySupported && retryableErr && retries != 0 { + if retrySupported && retryEnabled && retryableErr && retries != 0 { if op.Client != nil && op.Client.Committing { // Apply majority write concern for retries op.Client.UpdateCommitTransactionWriteConcern() @@ -975,13 +960,13 @@ func (op Operation) Execute(ctx context.Context) error { // If the operation isn't being retried, process the response if op.ProcessResponseFn != nil { info := ResponseInfo{ - ServerResponse: res, Server: srvr, Connection: conn, ConnectionDescription: desc.Server, CurrentIndex: currIndex, + Error: tt, } - _ = op.ProcessResponseFn(info) + _ = op.ProcessResponseFn(ctx, res, info) } if op.Client != nil && op.Client.Committing && (retryableErr || tt.Code == 50) { @@ -995,27 +980,27 @@ func (op Operation) Execute(ctx context.Context) error { } if op.ProcessResponseFn != nil { info := ResponseInfo{ - ServerResponse: res, Server: srvr, Connection: conn, ConnectionDescription: desc.Server, CurrentIndex: currIndex, + Error: tt, + } + perr := op.ProcessResponseFn(ctx, res, info) + if perr != nil { + return perr } - perr = op.ProcessResponseFn(info) - } - if perr != nil { - return perr } default: if op.ProcessResponseFn != nil { info := ResponseInfo{ - ServerResponse: res, Server: srvr, Connection: conn, ConnectionDescription: desc.Server, CurrentIndex: currIndex, + Error: tt, } - _ = op.ProcessResponseFn(info) + _ = op.ProcessResponseFn(ctx, res, info) } return err } @@ -1023,23 +1008,22 @@ func (op Operation) Execute(ctx context.Context) error { // If we're batching and there are batches remaining, advance to the next batch. 
This isn't // a retry, so increment the transaction number, reset the retries number, and don't set // server or connection to nil to continue using the same connection. - if batching && len(op.Batches.Documents) > 0 { + if op.Batches != nil && op.Batches.Size() > startedInfo.processedBatches { // If retries are supported for the current operation on the current server description, // the session isn't nil, and client retries are enabled, increment the txn number. // Calling IncrementTxnNumber() for server descriptions or topologies that do not // support retries (e.g. standalone topologies) will cause server errors. - if retrySupported && op.Client != nil && op.RetryMode != nil { - if op.RetryMode.Enabled() { - op.Client.IncrementTxnNumber() - } + if retrySupported && op.Client != nil && retryEnabled { + op.Client.IncrementTxnNumber() + // Reset the retries number for RetryOncePerCommand unless context is a Timeout context, in // which case retries should remain as -1 (as many times as possible). 
if *op.RetryMode == RetryOncePerCommand && !csot.IsTimeoutContext(ctx) { retries = 1 } } - currIndex += len(op.Batches.Current) - op.Batches.ClearBatch() + currIndex += startedInfo.processedBatches + op.Batches.AdvanceBatches(startedInfo.processedBatches) continue } break @@ -1196,26 +1180,14 @@ func (Operation) decompressWireMessage(wm []byte) (wiremessage.OpCode, []byte, e return opcode, uncompressed, nil } -func (op Operation) addBatchArray(dst []byte) []byte { - aidx, dst := bsoncore.AppendArrayElementStart(dst, op.Batches.Identifier) - for i, doc := range op.Batches.Current { - dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(i), doc) - } - dst, _ = bsoncore.AppendArrayEnd(dst, aidx) - return dst -} - func (op Operation) createLegacyHandshakeWireMessage( ctx context.Context, maxTimeMS int64, dst []byte, desc description.SelectedServer, -) ([]byte, startedInformation, error) { - var info startedInformation + cmdFn func([]byte, description.SelectedServer) ([]byte, error), +) ([]byte, []byte, error) { flags := op.secondaryOK(desc) - var wmindex int32 - info.requestID = wiremessage.NextRequestID() - wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery) dst = wiremessage.AppendQueryFlags(dst, flags) dollarCmd := [...]byte{'.', '$', 'c', 'm', 'd'} @@ -1230,35 +1202,31 @@ func (op Operation) createLegacyHandshakeWireMessage( wrapper := int32(-1) rp, err := op.createReadPref(desc, true) if err != nil { - return dst, info, err + return dst, nil, err } if len(rp) > 0 { wrapper, dst = bsoncore.AppendDocumentStart(dst) dst = bsoncore.AppendHeader(dst, bsoncore.TypeEmbeddedDocument, "$query") } idx, dst := bsoncore.AppendDocumentStart(dst) - dst, err = op.CommandFn(dst, desc) + dst, err = cmdFn(dst, desc) if err != nil { - return dst, info, err - } - - if op.Batches != nil && len(op.Batches.Current) > 0 { - dst = op.addBatchArray(dst) + return dst, nil, err } dst, err = op.addReadConcern(dst, desc) if err != nil { - return dst, 
info, err + return dst, nil, err } dst, err = op.addWriteConcern(ctx, dst, desc) if err != nil { - return dst, info, err + return dst, nil, err } - dst, err = op.addSession(dst, desc) + dst, err = op.addSession(dst, desc, false) if err != nil { - return dst, info, err + return dst, nil, err } dst = op.addClusterTime(dst, desc) @@ -1270,19 +1238,18 @@ func (op Operation) createLegacyHandshakeWireMessage( } dst, _ = bsoncore.AppendDocumentEnd(dst, idx) - // Command monitoring only reports the document inside $query - info.cmd = dst[idx:] if len(rp) > 0 { + idx = wrapper var err error dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) - dst, err = bsoncore.AppendDocumentEnd(dst, wrapper) + dst, err = bsoncore.AppendDocumentEnd(dst, idx) if err != nil { - return dst, info, err + return dst, nil, err } } - return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil + return dst, dst[idx:], nil } func (op Operation) createMsgWireMessage( @@ -1291,45 +1258,40 @@ func (op Operation) createMsgWireMessage( dst []byte, desc description.SelectedServer, conn *mnet.Connection, - requestID int32, -) ([]byte, startedInformation, error) { - var info startedInformation + cmdFn func([]byte, description.SelectedServer) ([]byte, error), +) ([]byte, []byte, error) { var flags wiremessage.MsgFlag - var wmindex int32 - // We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either - // aren't batching or we are encoding the last batch. - if op.WriteConcern != nil && !op.WriteConcern.Acknowledged() && (op.Batches == nil || len(op.Batches.Documents) == 0) { - flags = wiremessage.MoreToCome - } // Set the ExhaustAllowed flag if the connection supports streaming. This will tell the server that it can // respond with the MoreToCome flag and then stream responses over this connection. 
if streamer := conn.Streamer; streamer != nil && streamer.SupportsStreaming() { - flags |= wiremessage.ExhaustAllowed + flags = wiremessage.ExhaustAllowed } - - info.requestID = requestID - wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg) dst = wiremessage.AppendMsgFlags(dst, flags) // Body dst = wiremessage.AppendMsgSectionType(dst, wiremessage.SingleDocument) idx, dst := bsoncore.AppendDocumentStart(dst) - dst, err := op.addCommandFields(ctx, dst, desc) + var err error + dst, err = cmdFn(dst, desc) if err != nil { - return dst, info, err + return dst, nil, err } dst, err = op.addReadConcern(dst, desc) if err != nil { - return dst, info, err + return dst, nil, err } dst, err = op.addWriteConcern(ctx, dst, desc) if err != nil { - return dst, info, err + return dst, nil, err + } + retryWrite := false + if op.retryable(conn.Description()) && op.RetryMode != nil && op.RetryMode.Enabled() { + retryWrite = true } - dst, err = op.addSession(dst, desc) + dst, err = op.addSession(dst, desc, retryWrite) if err != nil { - return dst, info, err + return dst, nil, err } dst = op.addClusterTime(dst, desc) @@ -1343,34 +1305,15 @@ func (op Operation) createMsgWireMessage( dst = bsoncore.AppendStringElement(dst, "$db", op.Database) rp, err := op.createReadPref(desc, false) if err != nil { - return dst, info, err + return dst, nil, err } if len(rp) > 0 { dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) } dst, _ = bsoncore.AppendDocumentEnd(dst, idx) - // The command document for monitoring shouldn't include the type 1 payload as a document sequence - info.cmd = dst[idx:] - // add batch as a document sequence if auto encryption is not enabled - // if auto encryption is enabled, the batch will already be an array in the command document - if !op.shouldEncrypt() && op.Batches != nil && len(op.Batches.Current) > 0 { - info.documentSequenceIncluded = true - dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence) - 
idx, dst = bsoncore.ReserveLength(dst) - - dst = append(dst, op.Batches.Identifier...) - dst = append(dst, 0x00) - - for _, doc := range op.Batches.Current { - dst = append(dst, doc...) - } - - dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) - } - - return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil + return dst, dst[idx:], nil } // isLegacyHandshake returns True if the operation is the first message of @@ -1388,46 +1331,167 @@ func (op Operation) createWireMessage( desc description.SelectedServer, conn *mnet.Connection, requestID int32, -) ([]byte, startedInformation, error) { - if isLegacyHandshake(op, desc) { - return op.createLegacyHandshakeWireMessage(ctx, maxTimeMS, dst, desc) - } +) ([]byte, bool, startedInformation, error) { + var info startedInformation + var wmindex int32 + var err error - return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, requestID) -} + unacknowledged := op.WriteConcern != nil && !op.WriteConcern.Acknowledged() -// addCommandFields adds the fields for a command to the wire message in dst. This assumes that the start of the document -// has already been added and does not add the final 0 byte. 
-func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) { - if !op.shouldEncrypt() { - return op.CommandFn(dst, desc) + fIdx := -1 + isLegacy := isLegacyHandshake(op, desc) + switch { + case isLegacy: + cmdFn := func(dst []byte, desc description.SelectedServer) ([]byte, error) { + info.processedBatches, dst, err = op.addLegacyCommandFields(dst, desc) + return dst, err + } + requestID := wiremessage.NextRequestID() + wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpQuery) + dst, info.cmd, err = op.createLegacyHandshakeWireMessage(ctx, maxTimeMS, dst, desc, cmdFn) + case op.shouldEncrypt(): + if desc.WireVersion.Max < cryptMinWireVersion { + return dst, false, info, errors.New("auto-encryption requires a MongoDB version of 4.2") + } + cmdFn := func(dst []byte, desc description.SelectedServer) ([]byte, error) { + info.processedBatches, dst, err = op.addEncryptCommandFields(ctx, dst, desc) + return dst, err + } + wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpMsg) + fIdx = len(dst) + dst, info.cmd, err = op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, cmdFn) + default: + wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpMsg) + fIdx = len(dst) + + batchOffset := -1 + switch op.Batches.(type) { + case *Batches: + dst, info.cmd, err = op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, op.CommandFn) + if err == nil && op.Batches != nil { + batchOffset = len(dst) + info.processedBatches, dst, err = op.Batches.AppendBatchSequence(dst, + int(desc.MaxBatchCount), int(desc.MaxDocumentSize), int(desc.MaxDocumentSize), + ) + if err != nil { + break + } + if info.processedBatches == 0 { + err = ErrDocumentTooLarge + } + } + default: + var batches []byte + if op.Batches != nil { + maxDocSize := -1 + if unacknowledged { + maxDocSize = int(desc.MaxDocumentSize) + } + info.processedBatches, batches, err = 
op.Batches.AppendBatchSequence(batches, + int(desc.MaxBatchCount), maxDocSize, int(desc.MaxMessageSize), + ) + if err != nil { + break + } + if info.processedBatches == 0 { + err = ErrDocumentTooLarge + break + } + } + dst, info.cmd, err = op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, op.CommandFn) + if err == nil && len(batches) > 0 { + batchOffset = len(dst) + dst = append(dst, batches...) + } + } + if err == nil && batchOffset > 0 { + for b := dst[batchOffset:]; len(b) > 0; /* nothing */ { + var seq []byte + var ok bool + seq, b, ok = wiremessage.DocumentSequenceToArray(b) + if !ok { + break + } + info.documentSequence = append(info.documentSequence, seq...) + } + } + } + if err != nil { + return nil, false, info, err } - if desc.WireVersion.Max < cryptMinWireVersion { - return dst, errors.New("auto-encryption requires a MongoDB version of 4.2") + var moreToCome bool + // We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either + // aren't batching or we are encoding the last batch. 
+ batching := op.Batches != nil && op.Batches.Size() > info.processedBatches + if fIdx > 0 && unacknowledged && !batching { + dst[fIdx] |= byte(wiremessage.MoreToCome) + moreToCome = true } + info.requestID = requestID + return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), moreToCome, info, nil +} - // create temporary command document - cidx, cmdDst := bsoncore.AppendDocumentStart(nil) +func (op Operation) addEncryptCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) (int, []byte, error) { + idx, cmdDst := bsoncore.AppendDocumentStart(nil) var err error + // create temporary command document cmdDst, err = op.CommandFn(cmdDst, desc) if err != nil { - return dst, err + return 0, nil, err } - // use a BSON array instead of a type 1 payload because mongocryptd will convert to arrays regardless - if op.Batches != nil && len(op.Batches.Current) > 0 { - cmdDst = op.addBatchArray(cmdDst) + var n int + if op.Batches != nil { + if maxBatchCount := int(desc.MaxBatchCount); maxBatchCount > 1 { + n, cmdDst, err = op.Batches.AppendBatchArray(cmdDst, maxBatchCount, cryptMaxBsonObjectSize, cryptMaxBsonObjectSize) + if err != nil { + return 0, nil, err + } + } + if n == 0 { + maxDocumentSize := int(desc.MaxDocumentSize) + n, cmdDst, err = op.Batches.AppendBatchArray(cmdDst, 1, maxDocumentSize, maxDocumentSize) + if err != nil { + return 0, nil, err + } + if n == 0 { + return 0, nil, ErrDocumentTooLarge + } + } + } + cmdDst, err = bsoncore.AppendDocumentEnd(cmdDst, idx) + if err != nil { + return 0, nil, err } - cmdDst, _ = bsoncore.AppendDocumentEnd(cmdDst, cidx) - // encrypt the command encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst) if err != nil { - return dst, err + return 0, nil, err } // append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator) dst = append(dst, encrypted[4:len(encrypted)-1]...) 
- return dst, nil + return n, dst, nil +} + +func (op Operation) addLegacyCommandFields(dst []byte, desc description.SelectedServer) (int, []byte, error) { + var err error + dst, err = op.CommandFn(dst, desc) + if err != nil { + return 0, nil, err + } + if op.Batches == nil { + return 0, dst, nil + } + var n int + maxDocumentSize := int(desc.MaxDocumentSize) + n, dst, err = op.Batches.AppendBatchArray(dst, int(desc.MaxBatchCount), maxDocumentSize, maxDocumentSize) + if err != nil { + return 0, nil, err + } + if n == 0 { + return 0, nil, ErrDocumentTooLarge + } + return n, dst, nil } // addServerAPI adds the relevant fields for server API specification to the wire message in dst. @@ -1600,7 +1664,7 @@ func (op Operation) addWriteConcern(ctx context.Context, dst []byte, desc descri return append(bsoncore.AppendHeader(dst, bsoncore.Type(typ), "writeConcern"), wcBSON...), nil } -func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]byte, error) { +func (op Operation) addSession(dst []byte, desc description.SelectedServer, retryWrite bool) ([]byte, error) { client := op.Client // If the operation is defined for an explicit session but the server @@ -1618,7 +1682,7 @@ func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]b dst = bsoncore.AppendDocumentElement(dst, "lsid", client.SessionID) var addedTxnNumber bool - if op.Type == Write && client.RetryWrite { + if op.Type == Write && retryWrite { addedTxnNumber = true dst = bsoncore.AppendInt64Element(dst, "txnNumber", op.Client.TxnNumber) } @@ -2033,8 +2097,7 @@ func (op Operation) publishStartedEvent(ctx context.Context, info startedInforma if op.canLogCommandMessage() { host, port, _ := net.SplitHostPort(info.serverAddress.String()) - redactedCmd := redactStartedInformationCmd(op, info) - + redactedCmd := redactStartedInformationCmd(info) formattedCmd := logger.FormatDocument(redactedCmd, op.Logger.MaxDocumentLength) op.Logger.Print(logger.LevelDebug, @@ -2057,7 +2120,7 
@@ func (op Operation) publishStartedEvent(ctx context.Context, info startedInforma if op.canPublishStartedEvent() { started := &event.CommandStartedEvent{ - Command: redactStartedInformationCmd(op, info), + Command: redactStartedInformationCmd(info), DatabaseName: op.Database, CommandName: info.cmdName, RequestID: int64(info.requestID), diff --git a/x/mongo/driver/operation/abort_transaction.go b/x/mongo/driver/operation/abort_transaction.go index f6fd1f1ada..f5f1d3f344 100644 --- a/x/mongo/driver/operation/abort_transaction.go +++ b/x/mongo/driver/operation/abort_transaction.go @@ -41,9 +41,8 @@ func NewAbortTransaction() *AbortTransaction { return &AbortTransaction{} } -func (at *AbortTransaction) processResponse(driver.ResponseInfo) error { - var err error - return err +func (at *AbortTransaction) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error { + return nil } // Execute runs this operations and returns an error if the operation did not execute successfully. 
diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go index 5b5fd02192..396a569389 100644 --- a/x/mongo/driver/operation/aggregate.go +++ b/x/mongo/driver/operation/aggregate.go @@ -77,10 +77,12 @@ func (a *Aggregate) ResultCursorResponse() driver.CursorResponse { return a.result } -func (a *Aggregate) processResponse(info driver.ResponseInfo) error { - var err error - - a.result, err = driver.NewCursorResponse(info) +func (a *Aggregate) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error { + curDoc, err := driver.ExtractCursorDocument(resp) + if err != nil { + return err + } + a.result, err = driver.NewCursorResponse(curDoc, info) return err } diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index b1520f54fa..18dd428cd5 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -82,11 +82,15 @@ func (c *Command) Execute(ctx context.Context) error { CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) { return append(dst, c.command[4:len(c.command)-1]...), nil }, - ProcessResponseFn: func(info driver.ResponseInfo) error { - c.resultResponse = info.ServerResponse + ProcessResponseFn: func(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error { + c.resultResponse = resp if c.createCursor { - cursorRes, err := driver.NewCursorResponse(info) + curDoc, err := driver.ExtractCursorDocument(resp) + if err != nil { + return err + } + cursorRes, err := driver.NewCursorResponse(curDoc, info) if err != nil { return err } diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go index 572fb5607b..4c9eca5357 100644 --- a/x/mongo/driver/operation/commit_transaction.go +++ b/x/mongo/driver/operation/commit_transaction.go @@ -40,9 +40,8 @@ func NewCommitTransaction() *CommitTransaction { return &CommitTransaction{} } -func (ct 
*CommitTransaction) processResponse(driver.ResponseInfo) error { - var err error - return err +func (ct *CommitTransaction) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error { + return nil } // Execute runs this operations and returns an error if the operation did not execute successfully. diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index b3d201612a..5ecaa3a936 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -97,9 +97,9 @@ func NewCount() *Count { // Result returns the result of executing this operation. func (c *Count) Result() CountResult { return c.result } -func (c *Count) processResponse(info driver.ResponseInfo) error { +func (c *Count) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { var err error - c.result, err = buildCountResult(info.ServerResponse) + c.result, err = buildCountResult(resp) return err } diff --git a/x/mongo/driver/operation/create.go b/x/mongo/driver/operation/create.go index 911ddb9b4b..840d0ba469 100644 --- a/x/mongo/driver/operation/create.go +++ b/x/mongo/driver/operation/create.go @@ -57,7 +57,7 @@ func NewCreate(collectionName string) *Create { } } -func (c *Create) processResponse(driver.ResponseInfo) error { +func (c *Create) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error { return nil } diff --git a/x/mongo/driver/operation/create_indexes.go b/x/mongo/driver/operation/create_indexes.go index e878ae9c29..0380a55a26 100644 --- a/x/mongo/driver/operation/create_indexes.go +++ b/x/mongo/driver/operation/create_indexes.go @@ -91,9 +91,9 @@ func NewCreateIndexes(indexes bsoncore.Document) *CreateIndexes { // Result returns the result of executing this operation. 
func (ci *CreateIndexes) Result() CreateIndexesResult { return ci.result } -func (ci *CreateIndexes) processResponse(info driver.ResponseInfo) error { +func (ci *CreateIndexes) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { var err error - ci.result, err = buildCreateIndexesResult(info.ServerResponse) + ci.result, err = buildCreateIndexesResult(resp) return err } diff --git a/x/mongo/driver/operation/create_search_indexes.go b/x/mongo/driver/operation/create_search_indexes.go index 9b3a305ba9..2e3d09bb29 100644 --- a/x/mongo/driver/operation/create_search_indexes.go +++ b/x/mongo/driver/operation/create_search_indexes.go @@ -93,9 +93,9 @@ func NewCreateSearchIndexes(indexes bsoncore.Document) *CreateSearchIndexes { // Result returns the result of executing this operation. func (csi *CreateSearchIndexes) Result() CreateSearchIndexesResult { return csi.result } -func (csi *CreateSearchIndexes) processResponse(info driver.ResponseInfo) error { +func (csi *CreateSearchIndexes) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { var err error - csi.result, err = buildCreateSearchIndexesResult(info.ServerResponse) + csi.result, err = buildCreateSearchIndexesResult(resp) return err } diff --git a/x/mongo/driver/operation/delete.go b/x/mongo/driver/operation/delete.go index fe1f9a202a..39420efc58 100644 --- a/x/mongo/driver/operation/delete.go +++ b/x/mongo/driver/operation/delete.go @@ -80,8 +80,8 @@ func NewDelete(deletes ...bsoncore.Document) *Delete { // Result returns the result of executing this operation. 
func (d *Delete) Result() DeleteResult { return d.result } -func (d *Delete) processResponse(info driver.ResponseInfo) error { - dr, err := buildDeleteResult(info.ServerResponse) +func (d *Delete) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { + dr, err := buildDeleteResult(resp) d.result.N += dr.N return err } diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go index cfff2f7c31..19e8edd80f 100644 --- a/x/mongo/driver/operation/distinct.go +++ b/x/mongo/driver/operation/distinct.go @@ -75,9 +75,9 @@ func NewDistinct(key string, query bsoncore.Document) *Distinct { // Result returns the result of executing this operation. func (d *Distinct) Result() DistinctResult { return d.result } -func (d *Distinct) processResponse(info driver.ResponseInfo) error { +func (d *Distinct) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { var err error - d.result, err = buildDistinctResult(info.ServerResponse) + d.result, err = buildDistinctResult(resp) return err } diff --git a/x/mongo/driver/operation/drop_collection.go b/x/mongo/driver/operation/drop_collection.go index e3cb059a50..5905d6482f 100644 --- a/x/mongo/driver/operation/drop_collection.go +++ b/x/mongo/driver/operation/drop_collection.go @@ -79,9 +79,9 @@ func NewDropCollection() *DropCollection { // Result returns the result of executing this operation. 
func (dc *DropCollection) Result() DropCollectionResult { return dc.result } -func (dc *DropCollection) processResponse(info driver.ResponseInfo) error { +func (dc *DropCollection) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { var err error - dc.result, err = buildDropCollectionResult(info.ServerResponse) + dc.result, err = buildDropCollectionResult(resp) return err } diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index 2514de4233..e57cff72ee 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -73,9 +73,9 @@ func NewDropIndexes(index any) *DropIndexes { // Result returns the result of executing this operation. func (di *DropIndexes) Result() DropIndexesResult { return di.result } -func (di *DropIndexes) processResponse(info driver.ResponseInfo) error { +func (di *DropIndexes) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { var err error - di.result, err = buildDropIndexesResult(info.ServerResponse) + di.result, err = buildDropIndexesResult(resp) return err } diff --git a/x/mongo/driver/operation/drop_search_index.go b/x/mongo/driver/operation/drop_search_index.go index 3a6adead32..3b9600e85d 100644 --- a/x/mongo/driver/operation/drop_search_index.go +++ b/x/mongo/driver/operation/drop_search_index.go @@ -69,9 +69,9 @@ func NewDropSearchIndex(index string) *DropSearchIndex { // Result returns the result of executing this operation. 
func (dsi *DropSearchIndex) Result() DropSearchIndexResult { return dsi.result } -func (dsi *DropSearchIndex) processResponse(info driver.ResponseInfo) error { +func (dsi *DropSearchIndex) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { var err error - dsi.result, err = buildDropSearchIndexResult(info.ServerResponse) + dsi.result, err = buildDropSearchIndexResult(resp) return err } diff --git a/x/mongo/driver/operation/end_sessions.go b/x/mongo/driver/operation/end_sessions.go index df44fb44be..9b3ec1e248 100644 --- a/x/mongo/driver/operation/end_sessions.go +++ b/x/mongo/driver/operation/end_sessions.go @@ -39,9 +39,8 @@ func NewEndSessions(sessionIDs bsoncore.Document) *EndSessions { } } -func (es *EndSessions) processResponse(driver.ResponseInfo) error { - var err error - return err +func (es *EndSessions) processResponse(context.Context, bsoncore.Document, driver.ResponseInfo) error { + return nil } // Execute runs this operations and returns an error if the operation did not execute successfully. 
diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index 21cb92eca0..b607cb14d7 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -78,9 +78,12 @@ func (f *Find) Result(opts driver.CursorOptions) (*driver.BatchCursor, error) { return driver.NewBatchCursor(f.result, f.session, f.clock, opts) } -func (f *Find) processResponse(info driver.ResponseInfo) error { - var err error - f.result, err = driver.NewCursorResponse(info) +func (f *Find) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error { + curDoc, err := driver.ExtractCursorDocument(resp) + if err != nil { + return err + } + f.result, err = driver.NewCursorResponse(curDoc, info) return err } diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index 9939939f92..505c56b06c 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -112,10 +112,10 @@ func NewFindAndModify(query bsoncore.Document) *FindAndModify { // Result returns the result of executing this operation. 
func (fam *FindAndModify) Result() FindAndModifyResult { return fam.result } -func (fam *FindAndModify) processResponse(info driver.ResponseInfo) error { +func (fam *FindAndModify) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { var err error - fam.result, err = buildFindAndModifyResult(info.ServerResponse) + fam.result, err = buildFindAndModifyResult(resp) return err } diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 4e3749aef4..16655f691d 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -588,8 +588,8 @@ func (h *Hello) createOperation() driver.Operation { CommandFn: h.command, Database: "admin", Deployment: h.d, - ProcessResponseFn: func(info driver.ResponseInfo) error { - h.res = info.ServerResponse + ProcessResponseFn: func(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { + h.res = resp return nil }, ServerAPI: h.serverAPI, @@ -613,8 +613,8 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, CommandFn: h.handshakeCommand, Deployment: deployment, Database: "admin", - ProcessResponseFn: func(info driver.ResponseInfo) error { - h.res = info.ServerResponse + ProcessResponseFn: func(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { + h.res = resp return nil }, ServerAPI: h.serverAPI, diff --git a/x/mongo/driver/operation/insert.go b/x/mongo/driver/operation/insert.go index f64c586c83..b48e2c85f3 100644 --- a/x/mongo/driver/operation/insert.go +++ b/x/mongo/driver/operation/insert.go @@ -79,8 +79,8 @@ func NewInsert(documents ...bsoncore.Document) *Insert { // Result returns the result of executing this operation. 
func (i *Insert) Result() InsertResult { return i.result } -func (i *Insert) processResponse(info driver.ResponseInfo) error { - ir, err := buildInsertResult(info.ServerResponse) +func (i *Insert) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { + ir, err := buildInsertResult(resp) i.result.N += ir.N return err } diff --git a/x/mongo/driver/operation/list_collections.go b/x/mongo/driver/operation/list_collections.go index b746251dbc..3f9b55a6c3 100644 --- a/x/mongo/driver/operation/list_collections.go +++ b/x/mongo/driver/operation/list_collections.go @@ -55,9 +55,12 @@ func (lc *ListCollections) Result(opts driver.CursorOptions) (*driver.BatchCurso return driver.NewBatchCursor(lc.result, lc.session, lc.clock, opts) } -func (lc *ListCollections) processResponse(info driver.ResponseInfo) error { - var err error - lc.result, err = driver.NewCursorResponse(info) +func (lc *ListCollections) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error { + curDoc, err := driver.ExtractCursorDocument(resp) + if err != nil { + return err + } + lc.result, err = driver.NewCursorResponse(curDoc, info) return err } diff --git a/x/mongo/driver/operation/listDatabases.go b/x/mongo/driver/operation/list_databases.go similarity index 98% rename from x/mongo/driver/operation/listDatabases.go rename to x/mongo/driver/operation/list_databases.go index db1ba560bd..7273a26e53 100644 --- a/x/mongo/driver/operation/listDatabases.go +++ b/x/mongo/driver/operation/list_databases.go @@ -135,10 +135,10 @@ func NewListDatabases(filter bsoncore.Document) *ListDatabases { // Result returns the result of executing this operation. 
func (ld *ListDatabases) Result() ListDatabasesResult { return ld.result } -func (ld *ListDatabases) processResponse(info driver.ResponseInfo) error { +func (ld *ListDatabases) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { var err error - ld.result, err = buildListDatabasesResult(info.ServerResponse) + ld.result, err = buildListDatabasesResult(resp) return err } diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go index df5a90acf8..d7642a19e1 100644 --- a/x/mongo/driver/operation/list_indexes.go +++ b/x/mongo/driver/operation/list_indexes.go @@ -53,10 +53,12 @@ func (li *ListIndexes) Result(opts driver.CursorOptions) (*driver.BatchCursor, e return driver.NewBatchCursor(li.result, clientSession, clock, opts) } -func (li *ListIndexes) processResponse(info driver.ResponseInfo) error { - var err error - - li.result, err = driver.NewCursorResponse(info) +func (li *ListIndexes) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error { + curDoc, err := driver.ExtractCursorDocument(resp) + if err != nil { + return err + } + li.result, err = driver.NewCursorResponse(curDoc, info) return err } diff --git a/x/mongo/driver/operation/update.go b/x/mongo/driver/operation/update.go index 612a5c1f9d..186f946313 100644 --- a/x/mongo/driver/operation/update.go +++ b/x/mongo/driver/operation/update.go @@ -123,8 +123,8 @@ func NewUpdate(updates ...bsoncore.Document) *Update { // Result returns the result of executing this operation. 
func (u *Update) Result() UpdateResult { return u.result } -func (u *Update) processResponse(info driver.ResponseInfo) error { - ur, err := buildUpdateResult(info.ServerResponse) +func (u *Update) processResponse(_ context.Context, resp bsoncore.Document, info driver.ResponseInfo) error { + ur, err := buildUpdateResult(resp) u.result.N += ur.N u.result.NModified += ur.NModified diff --git a/x/mongo/driver/operation/update_search_index.go b/x/mongo/driver/operation/update_search_index.go index 943de7c61e..81e3513fac 100644 --- a/x/mongo/driver/operation/update_search_index.go +++ b/x/mongo/driver/operation/update_search_index.go @@ -71,9 +71,9 @@ func NewUpdateSearchIndex(index string, definition bsoncore.Document) *UpdateSea // Result returns the result of executing this operation. func (usi *UpdateSearchIndex) Result() UpdateSearchIndexResult { return usi.result } -func (usi *UpdateSearchIndex) processResponse(info driver.ResponseInfo) error { +func (usi *UpdateSearchIndex) processResponse(_ context.Context, resp bsoncore.Document, _ driver.ResponseInfo) error { var err error - usi.result, err = buildUpdateSearchIndexResult(info.ServerResponse) + usi.result, err = buildUpdateSearchIndexResult(resp) return err } diff --git a/x/mongo/driver/operation_exhaust.go b/x/mongo/driver/operation_exhaust.go index a05a37c138..a643040d7c 100644 --- a/x/mongo/driver/operation_exhaust.go +++ b/x/mongo/driver/operation_exhaust.go @@ -27,10 +27,9 @@ func (op Operation) ExecuteExhaust(ctx context.Context, conn *mnet.Connection) e if op.ProcessResponseFn != nil { // Server, ConnectionDescription, and CurrentIndex are unused in this mode. 
info := ResponseInfo{ - ServerResponse: res, - Connection: conn, + Connection: conn, } - if err = op.ProcessResponseFn(info); err != nil { + if err = op.ProcessResponseFn(ctx, res, info); err != nil { return err } } diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index e084e9a024..cd89b707b4 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -100,8 +100,6 @@ type Client struct { RetryingCommit bool Committing bool Aborting bool - RetryWrite bool - RetryRead bool Snapshot bool // options for the current transaction diff --git a/x/mongo/driver/wiremessage/wiremessage.go b/x/mongo/driver/wiremessage/wiremessage.go index dd16cb7be0..9330499242 100644 --- a/x/mongo/driver/wiremessage/wiremessage.go +++ b/x/mongo/driver/wiremessage/wiremessage.go @@ -16,6 +16,7 @@ package wiremessage import ( "bytes" "encoding/binary" + "strconv" "strings" "sync/atomic" @@ -422,6 +423,39 @@ func ReadMsgSectionRawDocumentSequence(src []byte) (identifier string, data []by return identifier, rem, rest, true } +// DocumentSequenceToArray converts a document sequence in byte slice to an array. +func DocumentSequenceToArray(src []byte) (dst, rem []byte, ok bool) { + stype, rem, ok := ReadMsgSectionType(src) + if !ok || stype != DocumentSequence { + return nil, src, false + } + var identifier string + var ret []byte + identifier, rem, ret, ok = ReadMsgSectionRawDocumentSequence(rem) + if !ok { + return nil, src, false + } + + aidx, dst := bsoncore.AppendArrayElementStart(nil, identifier) + i := 0 + for { + var doc bsoncore.Document + doc, rem, ok = bsoncore.ReadDocument(rem) + if !ok { + break + } + dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(i), doc) + i++ + } + if len(rem) > 0 { + return nil, src, false + } + + dst, _ = bsoncore.AppendArrayEnd(dst, aidx) + + return dst, ret, true +} + // ReadMsgChecksum reads a checksum from src. 
func ReadMsgChecksum(src []byte) (checksum uint32, rem []byte, ok bool) { i32, rem, ok := readi32(src)