diff --git a/changestreams.go b/changestreams.go index cb5dda2f5..75cd5b1d4 100644 --- a/changestreams.go +++ b/changestreams.go @@ -1,10 +1,22 @@ package mgo +import ( + "fmt" + "reflect" + "sync" + + "gopkg.in/mgo.v2/bson" +) + type ChangeStream struct { iter *Iter options ChangeStreamOptions pipeline interface{} + resumeToken *bson.Raw + collection *Collection readPreference *ReadPreference + err error + m sync.Mutex } type ChangeStreamOptions struct { @@ -27,6 +39,74 @@ type ChangeStreamOptions struct { Collation *Collation } +// Next retrieves the next document from the change stream, blocking if necessary. +// Next returns true if a document was successfully unmarshalled into result, +// and false if an error occured. When Next returns false, the Err method should +// be called to check what error occurred during iteration. +// +// For example: +// +// pipeline := []bson.M{} +// +// changeStream := collection.Watch(pipeline, ChangeStreamOptions{}) +// for changeStream.Next(&changeDoc) { +// fmt.Printf("Change: %v\n", changeDoc) +// } +// +// if err := changeStream.Close(); err != nil { +// return err +// } +// +// If the pipeline used removes the _id field from the result, Next will error +// because the _id field is needed to resume iteration when an error occurs. +// +func (changeStream *ChangeStream) Next(result interface{}) bool { + // the err field is being constantly overwritten and we don't want the user to + // attempt to read it at this point so we lock. + changeStream.m.Lock() + + defer changeStream.m.Unlock() + + // if we are in a state of error, then don't continue. + if changeStream.err != nil { + return false + } + + var err error + + // attempt to fetch the change stream result. + err = changeStream.fetchResultSet(result) + if err == nil { + return true + } + + // check if the error is resumable + if !isResumableError(err) { + // error is not resumable, give up and return it to the user. + changeStream.err = err + return false + } + + // try to resume. + err = changeStream.resume() + if err != nil { + // we've not been able to successfully resume and should only try once, + // so we give up. + changeStream.err = err + return false + } + + // we've successfully resumed the changestream. + // try to fetch the next result. + err = changeStream.fetchResultSet(result) + if err != nil { + changeStream.err = err + return false + } + + return true +} + func constructChangeStreamPipeline(pipeline interface{}, options ChangeStreamOptions) interface{} { pipelinev := reflect.ValueOf(pipeline) @@ -38,21 +118,21 @@ func constructChangeStreamPipeline(pipeline interface{}, // construct the options to be used by the change notification // pipeline stage. - changeNotificationStageOptions := bson.M{} + changeStreamStageOptions := bson.M{} if options.FullDocument != "" { - changeNotificationStageOptions["fullDocument"] = options.FullDocument + changeStreamStageOptions["fullDocument"] = options.FullDocument } if options.ResumeAfter != nil { - changeNotificationStageOptions["resumeAfter"] = options.ResumeAfter + changeStreamStageOptions["resumeAfter"] = options.ResumeAfter } - changeNotificationStage := bson.M{"$changeNotification": changeNotificationStageOptions} + changeStreamStage := bson.M{"$changeStream": changeStreamStageOptions} pipeOfInterfaces := make([]interface{}, pipelinev.Len()+1) // insert the change notification pipeline stage at the beginning of the // aggregation. - pipeOfInterfaces[0] = changeNotificationStage + pipeOfInterfaces[0] = changeStreamStage // convert the passed in slice to a slice of interfaces. for i := 0; i < pipelinev.Len(); i++ { @@ -61,3 +141,102 @@ func constructChangeStreamPipeline(pipeline interface{}, var pipelineAsInterface interface{} = pipeOfInterfaces return pipelineAsInterface } + +func (changeStream *ChangeStream) resume() error { + // copy the information for the new socket. + + // Copy() destroys the sockets currently associated with this session + // so future uses will acquire a new socket against the newly selected DB. + newSession := changeStream.iter.session.Copy() + + // fetch the cursor from the iterator and use it to run a killCursors + // on the connection. + cursorId := changeStream.iter.op.cursorId + err := runKillCursorsOnSession(newSession, cursorId) + if err != nil { + return err + } + + // change out the old connection to the database with the new connection. + changeStream.collection.Database.Session = newSession + + // make a new pipeline containing the resume token. + changeStreamPipeline := constructChangeStreamPipeline(changeStream.pipeline, changeStream.options) + + // generate the new iterator with the new connection. + newPipe := changeStream.collection.Pipe(changeStreamPipeline) + changeStream.iter = newPipe.Iter() + changeStream.iter.isChangeStream = true + + return nil +} + +// fetchResumeToken unmarshals the _id field from the document, setting an error +// on the changeStream if it is unable to. +func (changeStream *ChangeStream) fetchResumeToken(rawResult *bson.Raw) error { + changeStreamResult := struct { + ResumeToken *bson.Raw `bson:"_id,omitempty"` + }{} + + err := rawResult.Unmarshal(&changeStreamResult) + if err != nil { + return err + } + + if changeStreamResult.ResumeToken == nil { + return fmt.Errorf("resume token missing from result") + } + + changeStream.resumeToken = changeStreamResult.ResumeToken + return nil +} + +func (changeStream *ChangeStream) fetchResultSet(result interface{}) error { + rawResult := bson.Raw{} + + // fetch the next set of documents from the cursor. + gotNext := changeStream.iter.Next(&rawResult) + + err := changeStream.iter.Err() + if err != nil { + return err + } + + if !gotNext && err == nil { + // If the iter.Err() method returns nil despite us not getting a next batch, + // it is becuase iter.Err() silences this case. + return ErrNotFound + } + + // grab the resumeToken from the results + if err := changeStream.fetchResumeToken(&rawResult); err != nil { + return err + } + + // put the raw results into the data structure the user provided. + if err := rawResult.Unmarshal(result); err != nil { + return err + } + return nil +} + +func isResumableError(err error) bool { + _, isQueryError := err.(*QueryError) + // if it is not a database error OR it is a database error, + // but the error is a notMaster error + return !isQueryError || isNotMasterError(err) +} + +func runKillCursorsOnSession(session *Session, cursorId int64) error { + socket, err := session.acquireSocket(true) + if err != nil { + return err + } + err = socket.Query(&killCursorsOp{[]int64{cursorId}}) + if err != nil { + return err + } + socket.Release() + + return nil +} diff --git a/session.go b/session.go index 108cdae88..d27dec3cc 100644 --- a/session.go +++ b/session.go @@ -138,7 +138,8 @@ type Iter struct { docsBeforeMore int timeout time.Duration timedout bool - findCmd bool + isFindCmd bool + isChangeStream bool } var ( @@ -993,6 +994,11 @@ func isAuthError(err error) bool { return ok && e.Code == 13 } +func isNotMasterError(err error) bool { + e, ok := err.(*QueryError) + return ok && strings.Contains(e.Message, "not master") +} + func (db *Database) runUserCmd(cmdName string, user *User) error { cmd := make(bson.D, 0, 16) cmd = append(cmd, bson.DocElem{cmdName, user.Username}) @@ -2382,7 +2388,7 @@ func (c *Collection) NewIter(session *Session, firstBatch []bson.Raw, cursorId i } if socket.ServerInfo().MaxWireVersion >= 4 && c.FullName != "admin.$cmd" { - iter.findCmd = true + iter.isFindCmd = true } iter.gotReply.L = &iter.m @@ -3550,7 +3556,7 @@ func (q *Query) Iter() *Iter { op.replyFunc = iter.op.replyFunc if prepareFindOp(socket, &op, limit) { - iter.findCmd = true + iter.isFindCmd = true } iter.server = socket.Server() @@ -3780,7 +3786,12 @@ func (iter *Iter) Next(result interface{}) bool { iter.m.Lock() iter.timedout = false timeout := time.Time{} + + // check should we expect more data. for iter.err == nil && iter.docData.Len() == 0 && (iter.docsToReceive > 0 || iter.op.cursorId != 0) { + // we should expect more data. + + // If we have yet to receive data, increment the timer until we timeout. if iter.docsToReceive == 0 { if iter.timeout >= 0 { if timeout.IsZero() { @@ -3792,6 +3803,7 @@ func (iter *Iter) Next(result interface{}) bool { return false } } + // run a getmore to fetch more data. iter.getMore() if iter.err != nil { break @@ -3800,6 +3812,7 @@ func (iter *Iter) Next(result interface{}) bool { iter.gotReply.Wait() } + // We have data from the getMore. // Exhaust available data before reporting any errors. if docData, ok := iter.docData.Pop().([]byte); ok { close := false @@ -3815,6 +3828,7 @@ func (iter *Iter) Next(result interface{}) bool { } } if iter.op.cursorId != 0 && iter.err == nil { + // we still have a live cursor and currently expect data. iter.docsBeforeMore-- if iter.docsBeforeMore == -1 { iter.getMore() @@ -4004,7 +4018,7 @@ func (iter *Iter) getMore() { } } var op interface{} - if iter.findCmd { + if iter.isFindCmd || iter.isChangeStream { op = iter.getMoreCmd() } else { op = &iter.op @@ -4608,7 +4622,7 @@ func (iter *Iter) replyFunc() replyFunc { } else { iter.err = ErrNotFound } - } else if iter.findCmd { + } else if iter.isFindCmd { debugf("Iter %p received reply document %d/%d (cursor=%d)", iter, docNum+1, int(op.replyDocs), op.cursorId) var findReply struct { Ok bool @@ -4620,7 +4634,7 @@ func (iter *Iter) replyFunc() replyFunc { iter.err = err } else if !findReply.Ok && findReply.Errmsg != "" { iter.err = &QueryError{Code: findReply.Code, Message: findReply.Errmsg} - } else if len(findReply.Cursor.FirstBatch) == 0 && len(findReply.Cursor.NextBatch) == 0 { + } else if !iter.isChangeStream && len(findReply.Cursor.FirstBatch) == 0 && len(findReply.Cursor.NextBatch) == 0 { iter.err = ErrNotFound } else { batch := findReply.Cursor.FirstBatch