diff --git a/README.md b/README.md index 6c87fa905..7531fe4e6 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ A [sub-package](https://godoc.org/github.com/globalsign/mgo/bson) that implement * Supports dropping all indexes on a collection ([details](https://github.com/globalsign/mgo/pull/25)) * Annotates log entries/profiler output with optional appName on 3.4+ ([details](https://github.com/globalsign/mgo/pull/28)) * Support for read-only [views](https://docs.mongodb.com/manual/core/views/) in 3.4+ ([details](https://github.com/globalsign/mgo/pull/33)) -* Support for [collations](https://docs.mongodb.com/manual/reference/collation/) in 3.4+ ([details](https://github.com/globalsign/mgo/pull/37)) +* Support for [collations](https://docs.mongodb.com/manual/reference/collation/) in 3.4+ ([details](https://github.com/globalsign/mgo/pull/37), [more](https://github.com/globalsign/mgo/pull/166)) * Provide BSON constants for convenience/sanity ([details](https://github.com/globalsign/mgo/pull/41)) * Consistently unmarshal time.Time values as UTC ([details](https://github.com/globalsign/mgo/pull/42)) * Enforces best practise coding guidelines ([details](https://github.com/globalsign/mgo/pull/44)) @@ -49,6 +49,15 @@ A [sub-package](https://godoc.org/github.com/globalsign/mgo/bson) that implement * Add BSON stream encoders ([details](https://github.com/globalsign/mgo/pull/127)) * Add integer map key support in the BSON encoder ([details](https://github.com/globalsign/mgo/pull/140)) * Support aggregation [collations](https://docs.mongodb.com/manual/reference/collation/) ([details](https://github.com/globalsign/mgo/pull/144)) +* Support encoding of inline struct references ([details](https://github.com/globalsign/mgo/pull/146)) +* Improved windows test harness ([details](https://github.com/globalsign/mgo/pull/158)) +* Improved type and nil handling in the BSON codec ([details](https://github.com/globalsign/mgo/pull/147/files), [more](https://github.com/globalsign/mgo/pull/181)) +* Separated network read/write timeouts ([details](https://github.com/globalsign/mgo/pull/161)) +* Expanded dial string configuration options ([details](https://github.com/globalsign/mgo/pull/162)) +* Implement MongoTimestamp ([details](https://github.com/globalsign/mgo/pull/171)) +* Support setting `writeConcern` for `findAndModify` operations ([details](https://github.com/globalsign/mgo/pull/185)) +* Add `ssl` to the dial string options ([details](https://github.com/globalsign/mgo/pull/184)) + --- @@ -59,23 +68,32 @@ A [sub-package](https://godoc.org/github.com/globalsign/mgo/bson) that implement * @BenLubar * @carldunham * @carter2000 +* @cedric-cordenier * @cezarsa +* @DaytonG +* @ddspog * @drichelson * @dvic * @eaglerayp * @feliixx * @fmpwizard * @gazoon +* @gedge * @gnawux * @idy * @jameinel +* @jefferickson * @johnlawsharrison * @KJTsanaktsidis +* @larrycinnabar * @mapete94 * @maxnoel * @mcspring +* @Mei-Zhao * @peterdeka * @Reenjii +* @roobre * @smoya * @steve-gray +* @tbruyelle * @wgallagher diff --git a/bson/bson.go b/bson/bson.go index 31beab191..eb87ef620 100644 --- a/bson/bson.go +++ b/bson/bson.go @@ -42,6 +42,7 @@ import ( "errors" "fmt" "io" + "math" "os" "reflect" "runtime" @@ -426,6 +427,36 @@ func Now() time.Time { // strange reason has its own datatype defined in BSON. type MongoTimestamp int64 +// Time returns the time part of ts which is stored with second precision. +func (ts MongoTimestamp) Time() time.Time { + return time.Unix(int64(uint64(ts)>>32), 0) +} + +// Counter returns the counter part of ts. +func (ts MongoTimestamp) Counter() uint32 { + return uint32(ts) +} + +// NewMongoTimestamp creates a timestamp using the given +// date `t` (with second precision) and counter `c` (unique for `t`). +// +// Returns an error if time `t` is not between 1970-01-01T00:00:00Z +// and 2106-02-07T06:28:15Z (inclusive). +// +// Note that two MongoTimestamps should never have the same (time, counter) combination: +// the caller must ensure the counter `c` is increased if creating multiple MongoTimestamp +// values for the same time `t` (ignoring fractions of seconds). +func NewMongoTimestamp(t time.Time, c uint32) (MongoTimestamp, error) { + u := t.Unix() + if u < 0 || u > math.MaxUint32 { + return -1, errors.New("invalid value for time") + } + + i := int64(u<<32 | int64(c)) + + return MongoTimestamp(i), nil +} + type orderKey int64 // MaxKey is a special value that compares higher than all other possible BSON @@ -746,6 +777,14 @@ func getStructInfo(st reflect.Type) (*structInfo, error) { return nil, errors.New("Option ,inline needs a map with string keys in struct " + st.String()) } inlineMap = info.Num + case reflect.Ptr: + // allow only pointer to struct + if kind := field.Type.Elem().Kind(); kind != reflect.Struct { + return nil, errors.New("Option ,inline allows a pointer only to a struct, was given pointer to " + kind.String()) + } + + field.Type = field.Type.Elem() + fallthrough case reflect.Struct: sinfo, err := getStructInfo(field.Type) if err != nil { @@ -765,7 +804,7 @@ func getStructInfo(st reflect.Type) (*structInfo, error) { fieldsList = append(fieldsList, finfo) } default: - panic("Option ,inline needs a struct value or map field") + panic("Option ,inline needs a struct value or a pointer to a struct or map field") } continue } diff --git a/bson/bson_test.go b/bson/bson_test.go index 406ede6ae..60dcde1ff 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -32,6 +32,8 @@ import ( "encoding/json" "encoding/xml" "errors" + "fmt" + "math/rand" "net/url" "reflect" "strings" @@ -271,6 +273,42 @@ func (s *S) TestMarshalBuffer(c *C) { c.Assert(data, DeepEquals, buf[:len(data)]) } +func (s *S) TestPtrInline(c *C) { + cases := []struct { + In interface{} + Out bson.M + }{ + { + In: inlinePtrStruct{A: 1, MStruct: &MStruct{M: 3}}, + Out: bson.M{"a": 1, "m": 3}, + }, + { // go deeper + In: inlinePtrPtrStruct{B: 10, inlinePtrStruct: &inlinePtrStruct{A: 20, MStruct: &MStruct{M: 30}}}, + Out: bson.M{"b": 10, "a": 20, "m": 30}, + }, + { + // nil embed struct + In: &inlinePtrStruct{A: 3}, + Out: bson.M{"a": 3}, + }, + { + // nil embed struct + In: &inlinePtrPtrStruct{B: 5}, + Out: bson.M{"b": 5}, + }, + } + + for _, cs := range cases { + data, err := bson.Marshal(cs.In) + c.Assert(err, IsNil) + var dataBSON bson.M + err = bson.Unmarshal(data, &dataBSON) + c.Assert(err, IsNil) + + c.Assert(dataBSON, DeepEquals, cs.Out) + } +} + // -------------------------------------------------------------------------- // Some one way marshaling operations which would unmarshal differently. @@ -713,8 +751,6 @@ var marshalErrorItems = []testItemType{ "Attempted to marshal empty Raw document"}, {bson.M{"w": bson.Raw{Kind: 0x3, Data: []byte{}}}, "Attempted to marshal empty Raw document"}, - {&inlineCantPtr{&struct{ A, B int }{1, 2}}, - "Option ,inline needs a struct value or map field"}, {&inlineDupName{1, struct{ A, B int }{2, 3}}, "Duplicated key 'a' in struct bson_test.inlineDupName"}, {&inlineDupMap{}, @@ -1171,8 +1207,19 @@ type inlineBadKeyMap struct { M map[int]int `bson:",inline"` } type inlineUnexported struct { - M map[string]interface{} `bson:",inline"` - unexported `bson:",inline"` + M map[string]interface{} `bson:",inline"` + unexported `bson:",inline"` +} +type MStruct struct { + M int `bson:"m,omitempty"` +} +type inlinePtrStruct struct { + A int + *MStruct `bson:",inline"` +} +type inlinePtrPtrStruct struct { + B int + *inlinePtrStruct `bson:",inline"` } type unexported struct { A int @@ -1229,11 +1276,11 @@ func (s ifaceSlice) GetBSON() (interface{}, error) { type ( MyString string - MyBytes []byte - MyBool bool - MyD []bson.DocElem - MyRawD []bson.RawDocElem - MyM map[string]interface{} + MyBytes []byte + MyBool bool + MyD []bson.DocElem + MyRawD []bson.RawDocElem + MyM map[string]interface{} ) var ( @@ -1888,3 +1935,105 @@ func (s *S) BenchmarkNewObjectId(c *C) { bson.NewObjectId() } } + +func (s *S) TestMarshalRespectNil(c *C) { + type T struct { + Slice []int + SlicePtr *[]int + Ptr *int + Map map[string]interface{} + MapPtr *map[string]interface{} + } + + bson.SetRespectNilValues(true) + defer bson.SetRespectNilValues(false) + + testStruct1 := T{} + + c.Assert(testStruct1.Slice, IsNil) + c.Assert(testStruct1.SlicePtr, IsNil) + c.Assert(testStruct1.Map, IsNil) + c.Assert(testStruct1.MapPtr, IsNil) + c.Assert(testStruct1.Ptr, IsNil) + + b, _ := bson.Marshal(testStruct1) + + testStruct2 := T{} + + bson.Unmarshal(b, &testStruct2) + + c.Assert(testStruct2.Slice, IsNil) + c.Assert(testStruct2.SlicePtr, IsNil) + c.Assert(testStruct2.Map, IsNil) + c.Assert(testStruct2.MapPtr, IsNil) + c.Assert(testStruct2.Ptr, IsNil) + + testStruct1 = T{ + Slice: []int{}, + SlicePtr: &[]int{}, + Map: map[string]interface{}{}, + MapPtr: &map[string]interface{}{}, + } + + c.Assert(testStruct1.Slice, NotNil) + c.Assert(testStruct1.SlicePtr, NotNil) + c.Assert(testStruct1.Map, NotNil) + c.Assert(testStruct1.MapPtr, NotNil) + + b, _ = bson.Marshal(testStruct1) + + testStruct2 = T{} + + bson.Unmarshal(b, &testStruct2) + + c.Assert(testStruct2.Slice, NotNil) + c.Assert(testStruct2.SlicePtr, NotNil) + c.Assert(testStruct2.Map, NotNil) + c.Assert(testStruct2.MapPtr, NotNil) +} + +func (s *S) TestMongoTimestampTime(c *C) { + t := time.Now() + ts, err := bson.NewMongoTimestamp(t, 123) + c.Assert(err, IsNil) + c.Assert(ts.Time().Unix(), Equals, t.Unix()) +} + +func (s *S) TestMongoTimestampCounter(c *C) { + rnd := rand.Uint32() + ts, err := bson.NewMongoTimestamp(time.Now(), rnd) + c.Assert(err, IsNil) + c.Assert(ts.Counter(), Equals, rnd) +} + +func (s *S) TestMongoTimestampError(c *C) { + t := time.Date(1969, time.December, 31, 23, 59, 59, 999, time.UTC) + ts, err := bson.NewMongoTimestamp(t, 321) + c.Assert(int64(ts), Equals, int64(-1)) + c.Assert(err, ErrorMatches, "invalid value for time") +} + +func ExampleNewMongoTimestamp() { + + var counter uint32 = 1 + var t time.Time + + for i := 1; i <= 3; i++ { + + if c := time.Now(); t.Unix() == c.Unix() { + counter++ + } else { + t = c + counter = 1 + } + + ts, err := bson.NewMongoTimestamp(t, counter) + if err != nil { + fmt.Printf("NewMongoTimestamp error: %v", err) + } else { + fmt.Printf("NewMongoTimestamp encoded timestamp: %d\n", ts) + } + + time.Sleep(500 * time.Millisecond) + } +} diff --git a/bson/compatibility.go b/bson/compatibility.go index 6afecf53c..66efd465f 100644 --- a/bson/compatibility.go +++ b/bson/compatibility.go @@ -1,7 +1,8 @@ package bson -// Current state of the JSON tag fallback option. +// Current state of the JSON tag fallback option. var useJSONTagFallback = false +var useRespectNilValues = false // SetJSONTagFallback enables or disables the JSON-tag fallback for structure tagging. When this is enabled, structures // without BSON tags on a field will fall-back to using the JSON tag (if present). @@ -14,3 +15,15 @@ func SetJSONTagFallback(state bool) { func JSONTagFallbackState() bool { return useJSONTagFallback } + +// SetRespectNilValues enables or disables serializing nil slices or maps to `null` values. +// In other words it enables `encoding/json` compatible behaviour. +func SetRespectNilValues(state bool) { + useRespectNilValues = state +} + +// RespectNilValuesState returns the current status of the JSON nil slices and maps fallback compatibility option. +// See SetRespectNilValues for more information. +func RespectNilValuesState() bool { + return useRespectNilValues +} diff --git a/bson/encode.go b/bson/encode.go index 7e0b84d77..d0c6b2a85 100644 --- a/bson/encode.go +++ b/bson/encode.go @@ -229,15 +229,48 @@ func (e *encoder) addStruct(v reflect.Value) { if info.Inline == nil { value = v.Field(info.Num) } else { - value = v.FieldByIndex(info.Inline) + // as pointers to struct are allowed here, + // there is no guarantee that pointer won't be nil. + // + // It is expected allowed behaviour + // so info.Inline MAY consist index to a nil pointer + // and that is why we safely call v.FieldByIndex and just continue on panic + field, errField := safeFieldByIndex(v, info.Inline) + if errField != nil { + continue + } + + value = field } if info.OmitEmpty && isZero(value) { continue } + if useRespectNilValues && + (value.Kind() == reflect.Slice || value.Kind() == reflect.Map) && + value.IsNil() { + e.addElem(info.Key, reflect.ValueOf(nil), info.MinSize) + continue + } e.addElem(info.Key, value, info.MinSize) } } +func safeFieldByIndex(v reflect.Value, index []int) (result reflect.Value, err error) { + defer func() { + if recovered := recover(); recovered != nil { + switch r := recovered.(type) { + case string: + err = fmt.Errorf("%s", r) + case error: + err = r + } + } + }() + + result = v.FieldByIndex(index) + return +} + func isZero(v reflect.Value) bool { switch v.Kind() { case reflect.String: diff --git a/cluster.go b/cluster.go index 4e54c5d81..ff431cac5 100644 --- a/cluster.go +++ b/cluster.go @@ -48,34 +48,26 @@ import ( type mongoCluster struct { sync.RWMutex - serverSynced sync.Cond - userSeeds []string - dynaSeeds []string - servers mongoServers - masters mongoServers - references int - syncing bool - direct bool - failFast bool - syncCount uint - setName string - cachedIndex map[string]bool - sync chan bool - dial dialer - appName string - minPoolSize int - maxIdleTimeMS int + serverSynced sync.Cond + userSeeds []string + dynaSeeds []string + servers mongoServers + masters mongoServers + references int + syncing bool + syncCount uint + cachedIndex map[string]bool + sync chan bool + dial dialer + dialInfo *DialInfo } -func newCluster(userSeeds []string, direct, failFast bool, dial dialer, setName string, appName string) *mongoCluster { +func newCluster(userSeeds []string, info *DialInfo) *mongoCluster { cluster := &mongoCluster{ userSeeds: userSeeds, references: 1, - direct: direct, - failFast: failFast, - dial: dial, - setName: setName, - appName: appName, + dial: dialer{info.Dial, info.DialServer}, + dialInfo: info, } cluster.serverSynced.L = cluster.RWMutex.RLocker() cluster.sync = make(chan bool, 1) @@ -147,7 +139,7 @@ type isMasterResult struct { func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResult) error { // Monotonic let's it talk to a slave and still hold the socket. - session := newSession(Monotonic, cluster, 10*time.Second) + session := newSession(Monotonic, cluster, cluster.dialInfo) session.setSocket(socket) var cmd = bson.D{{Name: "isMaster", Value: 1}} @@ -171,8 +163,8 @@ func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResul } // Include the application name if set - if cluster.appName != "" { - meta["application"] = bson.M{"name": cluster.appName} + if cluster.dialInfo.AppName != "" { + meta["application"] = bson.M{"name": cluster.dialInfo.AppName} } cmd = append(cmd, bson.DocElem{ @@ -190,19 +182,7 @@ type possibleTimeout interface { Timeout() bool } -var syncSocketTimeout = 5 * time.Second - func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerInfo, hosts []string, err error) { - var syncTimeout time.Duration - if raceDetector { - // This variable is only ever touched by tests. - globalMutex.Lock() - syncTimeout = syncSocketTimeout - globalMutex.Unlock() - } else { - syncTimeout = syncSocketTimeout - } - addr := server.Addr log("SYNC Processing ", addr, "...") @@ -210,7 +190,7 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI var result isMasterResult var tryerr error for retry := 0; ; retry++ { - if retry == 3 || retry == 1 && cluster.failFast { + if retry == 3 || retry == 1 && cluster.dialInfo.FailFast { return nil, nil, tryerr } if retry > 0 { @@ -222,16 +202,22 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI time.Sleep(syncShortDelay) } - // It's not clear what would be a good timeout here. Is it - // better to wait longer or to retry? - socket, _, err := server.AcquireSocket(0, syncTimeout) + // Don't ever hit the pool limit for syncing + config := cluster.dialInfo.Copy() + config.PoolLimit = 0 + + socket, _, err := server.AcquireSocket(config) if err != nil { tryerr = err logf("SYNC Failed to get socket to %s: %v", addr, err) continue } err = cluster.isMaster(socket, &result) + + // Restore the correct dial config before returning it to the pool + socket.dialInfo = cluster.dialInfo socket.Release() + if err != nil { tryerr = err logf("SYNC Command 'ismaster' to %s failed: %v", addr, err) @@ -241,9 +227,9 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI break } - if cluster.setName != "" && result.SetName != cluster.setName { - logf("SYNC Server %s is not a member of replica set %q", addr, cluster.setName) - return nil, nil, fmt.Errorf("server %s is not a member of replica set %q", addr, cluster.setName) + if cluster.dialInfo.ReplicaSetName != "" && result.SetName != cluster.dialInfo.ReplicaSetName { + logf("SYNC Server %s is not a member of replica set %q", addr, cluster.dialInfo.ReplicaSetName) + return nil, nil, fmt.Errorf("server %s is not a member of replica set %q", addr, cluster.dialInfo.ReplicaSetName) } if result.IsMaster { @@ -255,7 +241,7 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI } } else if result.Secondary { debugf("SYNC %s is a slave.", addr) - } else if cluster.direct { + } else if cluster.dialInfo.Direct { logf("SYNC %s in unknown state. Pretending it's a slave due to direct connection.", addr) } else { logf("SYNC %s is neither a master nor a slave.", addr) @@ -386,7 +372,7 @@ func (cluster *mongoCluster) syncServersLoop() { break } cluster.references++ // Keep alive while syncing. - direct := cluster.direct + direct := cluster.dialInfo.Direct cluster.Unlock() cluster.syncServersIteration(direct) @@ -401,7 +387,7 @@ func (cluster *mongoCluster) syncServersLoop() { // Hold off before allowing another sync. No point in // burning CPU looking for down servers. - if !cluster.failFast { + if !cluster.dialInfo.FailFast { time.Sleep(syncShortDelay) } @@ -439,13 +425,11 @@ func (cluster *mongoCluster) syncServersLoop() { func (cluster *mongoCluster) server(addr string, tcpaddr *net.TCPAddr) *mongoServer { cluster.RLock() server := cluster.servers.Search(tcpaddr.String()) - minPoolSize := cluster.minPoolSize - maxIdleTimeMS := cluster.maxIdleTimeMS cluster.RUnlock() if server != nil { return server } - return newServer(addr, tcpaddr, cluster.sync, cluster.dial, minPoolSize, maxIdleTimeMS) + return newServer(addr, tcpaddr, cluster.sync, cluster.dial, cluster.dialInfo) } func resolveAddr(addr string) (*net.TCPAddr, error) { @@ -614,19 +598,10 @@ func (cluster *mongoCluster) syncServersIteration(direct bool) { cluster.Unlock() } -// AcquireSocket returns a socket to a server in the cluster. If slaveOk is -// true, it will attempt to return a socket to a slave server. If it is -// false, the socket will necessarily be to a master server. -func (cluster *mongoCluster) AcquireSocket(mode Mode, slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int) (s *mongoSocket, err error) { - return cluster.AcquireSocketWithPoolTimeout(mode, slaveOk, syncTimeout, socketTimeout, serverTags, poolLimit, 0) -} - // AcquireSocketWithPoolTimeout returns a socket to a server in the cluster. If slaveOk is // true, it will attempt to return a socket to a slave server. If it is // false, the socket will necessarily be to a master server. -func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( - mode Mode, slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int, poolTimeout time.Duration, -) (s *mongoSocket, err error) { +func (cluster *mongoCluster) AcquireSocketWithPoolTimeout(mode Mode, slaveOk bool, syncTimeout time.Duration, serverTags []bson.D, info *DialInfo) (s *mongoSocket, err error) { var started time.Time var syncCount uint for { @@ -645,7 +620,7 @@ func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( // Initialize after fast path above. started = time.Now() syncCount = cluster.syncCount - } else if syncTimeout != 0 && started.Before(time.Now().Add(-syncTimeout)) || cluster.failFast && cluster.syncCount != syncCount { + } else if syncTimeout != 0 && started.Before(time.Now().Add(-syncTimeout)) || cluster.dialInfo.FailFast && cluster.syncCount != syncCount { cluster.RUnlock() return nil, errors.New("no reachable servers") } @@ -670,7 +645,7 @@ func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( continue } - s, abended, err := server.AcquireSocketWithBlocking(poolLimit, socketTimeout, poolTimeout) + s, abended, err := server.AcquireSocketWithBlocking(info) if err == errPoolTimeout { // No need to remove servers from the topology if acquiring a socket fails for this reason. return nil, err diff --git a/cluster_test.go b/cluster_test.go index be11dc1a7..de99d414d 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1055,8 +1055,6 @@ func (s *S) TestSocketTimeoutOnDial(c *C) { timeout := 1 * time.Second - defer mgo.HackSyncSocketTimeout(timeout)() - s.Freeze("localhost:40001") started := time.Now() diff --git a/dbtest/dbserver.go b/dbtest/dbserver.go index 2fadaf764..3840827f9 100644 --- a/dbtest/dbserver.go +++ b/dbtest/dbserver.go @@ -6,6 +6,7 @@ import ( "net" "os" "os/exec" + "runtime" "strconv" "time" @@ -70,7 +71,7 @@ func (dbs *DBServer) start() { err = dbs.server.Start() if err != nil { // print error to facilitate troubleshooting as the panic will be caught in a panic handler - fmt.Fprintf(os.Stderr, "mongod failed to start: %v\n",err) + fmt.Fprintf(os.Stderr, "mongod failed to start: %v\n", err) panic(err) } dbs.tomb.Go(dbs.monitor) @@ -113,7 +114,12 @@ func (dbs *DBServer) Stop() { } if dbs.server != nil { dbs.tomb.Kill(nil) - dbs.server.Process.Signal(os.Interrupt) + // Windows doesn't support Interrupt + if runtime.GOOS == "windows" { + dbs.server.Process.Signal(os.Kill) + } else { + dbs.server.Process.Signal(os.Interrupt) + } select { case <-dbs.tomb.Dead(): case <-time.After(5 * time.Second): diff --git a/example_test.go b/example_test.go index d176d5f5c..9775ba9e1 100644 --- a/example_test.go +++ b/example_test.go @@ -137,7 +137,21 @@ func ExampleSession_concurrency() { func ExampleDial_usingSSL() { // To connect via TLS/SSL (enforced for MongoDB Atlas for example) requires - // configuring the dialer to use a TLS connection: + // to set the ssl query param to true. + url := "mongodb://localhost:40003?ssl=true" + + session, err := Dial(url) + if err != nil { + panic(err) + } + + // Use session as normal + session.Close() +} + +func ExampleDial_tlsConfig() { + // You can define a custom tlsConfig, this one enables TLS, like if you have + // ssl=true in the connection string. url := "mongodb://localhost:40003" tlsConfig := &tls.Config{ diff --git a/export_test.go b/export_test.go index 998c7a2dd..1b7d7e941 100644 --- a/export_test.go +++ b/export_test.go @@ -19,20 +19,6 @@ func HackPingDelay(newDelay time.Duration) (restore func()) { return } -func HackSyncSocketTimeout(newTimeout time.Duration) (restore func()) { - globalMutex.Lock() - defer globalMutex.Unlock() - - oldTimeout := syncSocketTimeout - restore = func() { - globalMutex.Lock() - syncSocketTimeout = oldTimeout - globalMutex.Unlock() - } - syncSocketTimeout = newTimeout - return -} - func (s *Session) Cluster() *mongoCluster { return s.cluster() } diff --git a/internal/sasl/sasl.go b/internal/sasl/sasl.go index 25a537426..0b56f0b6f 100644 --- a/internal/sasl/sasl.go +++ b/internal/sasl/sasl.go @@ -127,6 +127,7 @@ func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, er if rc == C.SASL_CONTINUE { return clientData, false, nil } + return nil, false, saslError(rc, ss.conn, "cannot establish SASL session") } diff --git a/internal/scram/scram.go b/internal/scram/scram.go index d3ddd02fd..03c14daf7 100644 --- a/internal/scram/scram.go +++ b/internal/scram/scram.go @@ -91,7 +91,7 @@ func NewClient(newHash func() hash.Hash, user, pass string) *Client { // Out returns the data to be sent to the server in the current step. func (c *Client) Out() []byte { if c.out.Len() == 0 { - return nil + return []byte{} } return c.out.Bytes() } diff --git a/server.go b/server.go index f34624f74..6f51ca5e3 100644 --- a/server.go +++ b/server.go @@ -67,9 +67,8 @@ type mongoServer struct { pingCount uint32 closed bool abended bool - minPoolSize int - maxIdleTimeMS int poolWaiter *sync.Cond + dialInfo *DialInfo } type dialer struct { @@ -91,21 +90,20 @@ type mongoServerInfo struct { var defaultServerInfo mongoServerInfo -func newServer(addr string, tcpaddr *net.TCPAddr, syncChan chan bool, dial dialer, minPoolSize, maxIdleTimeMS int) *mongoServer { +func newServer(addr string, tcpaddr *net.TCPAddr, syncChan chan bool, dial dialer, info *DialInfo) *mongoServer { server := &mongoServer{ - Addr: addr, - ResolvedAddr: tcpaddr.String(), - tcpaddr: tcpaddr, - sync: syncChan, - dial: dial, - info: &defaultServerInfo, - pingValue: time.Hour, // Push it back before an actual ping. - minPoolSize: minPoolSize, - maxIdleTimeMS: maxIdleTimeMS, + Addr: addr, + ResolvedAddr: tcpaddr.String(), + tcpaddr: tcpaddr, + sync: syncChan, + dial: dial, + info: &defaultServerInfo, + pingValue: time.Hour, // Push it back before an actual ping. + dialInfo: info, } server.poolWaiter = sync.NewCond(server) go server.pinger(true) - if maxIdleTimeMS != 0 { + if info.MaxIdleTimeMS != 0 { go server.poolShrinker() } return server @@ -123,22 +121,18 @@ var errServerClosed = errors.New("server was closed") // If the poolLimit argument is greater than zero and the number of sockets in // use in this server is greater than the provided limit, errPoolLimit is // returned. -func (server *mongoServer) AcquireSocket(poolLimit int, timeout time.Duration) (socket *mongoSocket, abended bool, err error) { - return server.acquireSocketInternal(poolLimit, timeout, false, 0*time.Millisecond) +func (server *mongoServer) AcquireSocket(info *DialInfo) (socket *mongoSocket, abended bool, err error) { + return server.acquireSocketInternal(info, false) } // AcquireSocketWithBlocking wraps AcquireSocket, but if a socket is not available, it will _not_ // return errPoolLimit. Instead, it will block waiting for a socket to become available. If poolTimeout // should elapse before a socket is available, it will return errPoolTimeout. -func (server *mongoServer) AcquireSocketWithBlocking( - poolLimit int, socketTimeout time.Duration, poolTimeout time.Duration, -) (socket *mongoSocket, abended bool, err error) { - return server.acquireSocketInternal(poolLimit, socketTimeout, true, poolTimeout) +func (server *mongoServer) AcquireSocketWithBlocking(info *DialInfo) (socket *mongoSocket, abended bool, err error) { + return server.acquireSocketInternal(info, true) } -func (server *mongoServer) acquireSocketInternal( - poolLimit int, timeout time.Duration, shouldBlock bool, poolTimeout time.Duration, -) (socket *mongoSocket, abended bool, err error) { +func (server *mongoServer) acquireSocketInternal(info *DialInfo, shouldBlock bool) (socket *mongoSocket, abended bool, err error) { for { server.Lock() abended = server.abended @@ -146,7 +140,7 @@ func (server *mongoServer) acquireSocketInternal( server.Unlock() return nil, abended, errServerClosed } - if poolLimit > 0 { + if info.PoolLimit > 0 { if shouldBlock { // Beautiful. Golang conditions don't have a WaitWithTimeout, so I've implemented the timeout // with a wait + broadcast. The broadcast will cause the loop here to re-check the timeout, @@ -158,11 +152,11 @@ func (server *mongoServer) acquireSocketInternal( // https://github.com/golang/go/issues/16620, since the lock needs to be held in _this_ goroutine. waitDone := make(chan struct{}) timeoutHit := false - if poolTimeout > 0 { + if info.PoolTimeout > 0 { go func() { select { case <-waitDone: - case <-time.After(poolTimeout): + case <-time.After(info.PoolTimeout): // timeoutHit is part of the wait condition, so needs to be changed under mutex. server.Lock() defer server.Unlock() @@ -172,7 +166,7 @@ func (server *mongoServer) acquireSocketInternal( }() } timeSpentWaiting := time.Duration(0) - for len(server.liveSockets)-len(server.unusedSockets) >= poolLimit && !timeoutHit { + for len(server.liveSockets)-len(server.unusedSockets) >= info.PoolLimit && !timeoutHit { // We only count time spent in Wait(), and not time evaluating the entire loop, // so that in the happy non-blocking path where the condition above evaluates true // first time, we record a nice round zero wait time. @@ -191,7 +185,7 @@ func (server *mongoServer) acquireSocketInternal( // Record that we fetched a connection of of a socket list and how long we spent waiting stats.noticeSocketAcquisition(timeSpentWaiting) } else { - if len(server.liveSockets)-len(server.unusedSockets) >= poolLimit { + if len(server.liveSockets)-len(server.unusedSockets) >= info.PoolLimit { server.Unlock() return nil, false, errPoolLimit } @@ -202,15 +196,15 @@ func (server *mongoServer) acquireSocketInternal( socket = server.unusedSockets[n-1] server.unusedSockets[n-1] = nil // Help GC. server.unusedSockets = server.unusedSockets[:n-1] - info := server.info + serverInfo := server.info server.Unlock() - err = socket.InitialAcquire(info, timeout) + err = socket.InitialAcquire(serverInfo, info) if err != nil { continue } } else { server.Unlock() - socket, err = server.Connect(timeout) + socket, err = server.Connect(info) if err == nil { server.Lock() // We've waited for the Connect, see if we got @@ -231,20 +225,18 @@ func (server *mongoServer) acquireSocketInternal( // Connect establishes a new connection to the server. This should // generally be done through server.AcquireSocket(). -func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) { +func (server *mongoServer) Connect(info *DialInfo) (*mongoSocket, error) { server.RLock() master := server.info.Master dial := server.dial server.RUnlock() - logf("Establishing new connection to %s (timeout=%s)...", server.Addr, timeout) + logf("Establishing new connection to %s (timeout=%s)...", server.Addr, info.Timeout) var conn net.Conn var err error switch { case !dial.isSet(): - // Cannot do this because it lacks timeout support. :-( - //conn, err = net.DialTCP("tcp", nil, server.tcpaddr) - conn, err = net.DialTimeout("tcp", server.ResolvedAddr, timeout) + conn, err = net.DialTimeout("tcp", server.ResolvedAddr, info.Timeout) if tcpconn, ok := conn.(*net.TCPConn); ok { tcpconn.SetKeepAlive(true) } else if err == nil { @@ -264,7 +256,7 @@ func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) logf("Connection to %s established.", server.Addr) stats.conn(+1, master) - return newSocket(server, conn, timeout), nil + return newSocket(server, conn, info), nil } // Close forces closing all sockets that are alive, whether @@ -407,7 +399,8 @@ func (server *mongoServer) pinger(loop bool) { time.Sleep(delay) } op := op - socket, _, err := server.AcquireSocket(0, delay) + + socket, _, err := server.AcquireSocket(server.dialInfo) if err == nil { start := time.Now() _, _ = socket.SimpleQuery(&op) @@ -448,7 +441,7 @@ func (server *mongoServer) poolShrinker() { } server.Lock() unused := len(server.unusedSockets) - if unused < server.minPoolSize { + if unused < server.dialInfo.MinPoolSize { server.Unlock() continue } @@ -457,8 +450,8 @@ func (server *mongoServer) poolShrinker() { reclaimMap := map[*mongoSocket]struct{}{} // Because the acquisition and recycle are done at the tail of array, // the head is always the oldest unused socket. - for _, s := range server.unusedSockets[:unused-server.minPoolSize] { - if s.lastTimeUsed.Add(time.Duration(server.maxIdleTimeMS) * time.Millisecond).After(now) { + for _, s := range server.unusedSockets[:unused-server.dialInfo.MinPoolSize] { + if s.lastTimeUsed.Add(time.Duration(server.dialInfo.MaxIdleTimeMS) * time.Millisecond).After(now) { break } end++ @@ -572,7 +565,7 @@ func (servers *mongoServers) BestFit(mode Mode, serverTags []bson.D) *mongoServe if best == nil { best = next best.RLock() - if serverTags != nil && !next.info.Mongos && !best.hasTags(serverTags) { + if len(serverTags) != 0 && !next.info.Mongos && !best.hasTags(serverTags) { best.RUnlock() best = nil } @@ -581,7 +574,7 @@ func (servers *mongoServers) BestFit(mode Mode, serverTags []bson.D) *mongoServe next.RLock() swap := false switch { - case serverTags != nil && !next.info.Mongos && !next.hasTags(serverTags): + case len(serverTags) != 0 && !next.info.Mongos && !next.hasTags(serverTags): // Must have requested tags. case mode == Secondary && next.info.Master && !next.info.Mongos: // Must be a secondary or mongos. diff --git a/server_test.go b/server_test.go index 1d21ef08b..43ddfa3b1 100644 --- a/server_test.go +++ b/server_test.go @@ -29,8 +29,8 @@ package mgo_test import ( "time" - . "gopkg.in/check.v1" "github.com/globalsign/mgo" + . "gopkg.in/check.v1" ) func (s *S) TestServerRecoversFromAbend(c *C) { @@ -40,7 +40,13 @@ func (s *S) TestServerRecoversFromAbend(c *C) { // Peek behind the scenes cluster := session.Cluster() server := cluster.Server("127.0.0.1:40001") - sock, abended, err := server.AcquireSocket(100, time.Second) + + info := &mgo.DialInfo{ + Timeout: time.Second, + PoolLimit: 100, + } + + sock, abended, err := server.AcquireSocket(info) c.Assert(err, IsNil) c.Assert(sock, NotNil) sock.Release() @@ -49,15 +55,15 @@ func (s *S) TestServerRecoversFromAbend(c *C) { sock.Close() server.AbendSocket(sock) // Next acquire notices the connection was abnormally ended - sock, abended, err = server.AcquireSocket(100, time.Second) + sock, abended, err = server.AcquireSocket(info) c.Assert(err, IsNil) sock.Release() c.Check(abended, Equals, true) - // cluster.AcquireSocket should fix the abended problems - sock, err = cluster.AcquireSocket(mgo.Primary, false, time.Minute, time.Second, nil, 100) + // cluster.AcquireSocketWithPoolTimeout should fix the abended problems + sock, err = cluster.AcquireSocketWithPoolTimeout(mgo.Primary, false, time.Minute, nil, info) c.Assert(err, IsNil) sock.Release() - sock, abended, err = server.AcquireSocket(100, time.Second) + sock, abended, err = server.AcquireSocket(info) c.Assert(err, IsNil) c.Check(abended, Equals, false) sock.Release() diff --git a/session.go b/session.go index 5b98154f1..cd2a53e19 100644 --- a/session.go +++ b/session.go @@ -28,6 +28,7 @@ package mgo import ( "crypto/md5" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" @@ -73,6 +74,14 @@ const ( Monotonic Mode = 1 // Strong mode is specific to mgo, and is same as Primary. Strong Mode = 2 + + // DefaultConnectionPoolLimit defines the default maximum number of + // connections in the connection pool. + // + // To override this value set DialInfo.PoolLimit. + DefaultConnectionPoolLimit = 4096 + + zeroDuration = time.Duration(0) ) // mgo.v3: Drop Strong mode, suffix all modes with "Mode". @@ -90,9 +99,6 @@ type Session struct { defaultdb string sourcedb string syncTimeout time.Duration - sockTimeout time.Duration - poolLimit int - poolTimeout time.Duration consistency Mode creds []Credential dialCred *Credential @@ -104,6 +110,8 @@ type Session struct { queryConfig query bypassValidation bool slaveOk bool + + dialInfo *DialInfo } // Database holds collections of documents @@ -196,7 +204,7 @@ const ( // Dial will timeout after 10 seconds if a server isn't reached. The returned // session will timeout operations after one minute by default if servers aren't // available. To customize the timeout, see DialWithTimeout, SetSyncTimeout, and -// SetSocketTimeout. +// DialInfo Read/WriteTimeout. // // This method is generally called just once for a given cluster. Further // sessions to the same cluster are then established using the New or Copy @@ -287,6 +295,12 @@ const ( // The identifier of this client application. This parameter is used to // annotate logs / profiler output and cannot exceed 128 bytes. // +// ssl= +// +// true: Initiate the connection with TLS/SSL. +// false: Initiate the connection without TLS/SSL. +// The default value is false. +// // Relevant documentation: // // http://docs.mongodb.org/manual/reference/connection-string/ @@ -324,6 +338,7 @@ func ParseURL(url string) (*DialInfo, error) { if err != nil { return nil, err } + ssl := false direct := false mechanism := "" service := "" @@ -335,8 +350,13 @@ func ParseURL(url string) (*DialInfo, error) { var readPreferenceTagSets []bson.D minPoolSize := 0 maxIdleTimeMS := 0 + safe := Safe{} for _, opt := range uinfo.options { switch opt.key { + case "ssl": + if v, err := strconv.ParseBool(opt.value); err == nil && v { + ssl = true + } case "authSource": source = opt.value case "authMechanism": @@ -345,6 +365,23 @@ func ParseURL(url string) (*DialInfo, error) { service = opt.value case "replicaSet": setName = opt.value + case "w": + safe.WMode = opt.value + case "j": + journal, err := strconv.ParseBool(opt.value) + if err != nil { + return nil, errors.New("bad value for j: " + opt.value) + } + safe.J = journal + case "wtimeoutMS": + timeout, err := strconv.Atoi(opt.value) + if err != nil { + return nil, errors.New("bad value for wtimeoutMS: " + opt.value) + } + if timeout < 0 { + return nil, errors.New("bad value (negative) for wtimeoutMS: " + opt.value) + } + safe.WTimeout = timeout case "maxPoolSize": poolLimit, err = strconv.Atoi(opt.value) if err != nil { @@ -387,7 +424,7 @@ func ParseURL(url string) (*DialInfo, error) { return nil, errors.New("bad value for minPoolSize: " + opt.value) } if minPoolSize < 0 { - return nil, errors.New("bad value (negtive) for minPoolSize: " + opt.value) + return nil, errors.New("bad value (negative) for minPoolSize: " + opt.value) } case "maxIdleTimeMS": maxIdleTimeMS, err = strconv.Atoi(opt.value) @@ -395,7 +432,7 @@ func ParseURL(url string) (*DialInfo, error) { return nil, errors.New("bad value for maxIdleTimeMS: " + opt.value) } if maxIdleTimeMS < 0 { - return nil, errors.New("bad value (negtive) for maxIdleTimeMS: " + opt.value) + return nil, errors.New("bad value (negative) for maxIdleTimeMS: " + opt.value) } case "connect": if opt.value == "direct" { @@ -430,10 +467,18 @@ func ParseURL(url string) (*DialInfo, error) { Mode: readPreferenceMode, TagSets: readPreferenceTagSets, }, + Safe: safe, ReplicaSetName: setName, MinPoolSize: minPoolSize, MaxIdleTimeMS: maxIdleTimeMS, } + if ssl && info.DialServer == nil { + // Set DialServer only if nil, we don't want to override user's settings. + info.DialServer = func(addr *ServerAddr) (net.Conn, error) { + conn, err := tls.Dial("tcp", addr.String(), &tls.Config{}) + return conn, err + } + } return &info, nil } @@ -483,15 +528,38 @@ type DialInfo struct { Username string Password string - // PoolLimit defines the per-server socket pool limit. Defaults to 4096. - // See Session.SetPoolLimit for details. + // PoolLimit defines the per-server socket pool limit. Defaults to + // DefaultConnectionPoolLimit. See Session.SetPoolLimit for details. PoolLimit int // PoolTimeout defines max time to wait for a connection to become available - // if the pool limit is reaqched. Defaults to zero, which means forever. - // See Session.SetPoolTimeout for details + // if the pool limit is reached. Defaults to zero, which means forever. See + // Session.SetPoolTimeout for details PoolTimeout time.Duration + // ReadTimeout defines the maximum duration to wait for a response to be + // read from MongoDB. + // + // This effectively limits the maximum query execution time. If a MongoDB + // query duration exceeds this timeout, the caller will receive a timeout, + // however MongoDB will continue processing the query. This duration must be + // large enough to allow MongoDB to execute the query, and the response be + // received over the network connection. + // + // Only limits the network read - does not include unmarshalling / + // processing of the response. Defaults to DialInfo.Timeout. If 0, no + // timeout is set. + ReadTimeout time.Duration + + // WriteTimeout defines the maximum duration of a write to MongoDB over the + // network connection. + // + // This is can usually be low unless writing large documents, or over a high + // latency link. Only limits network write time - does not include + // marshalling/processing the request. Defaults to DialInfo.Timeout. If 0, + // no timeout is set. + WriteTimeout time.Duration + // The identifier of the client application which ran the operation. AppName string @@ -499,6 +567,9 @@ type DialInfo struct { // Session.SetMode and Session.SelectServers. ReadPreference *ReadPreference + // Safe mostly defines write options, though there is RMode. See Session.SetSafe + Safe Safe + // FailFast will cause connection and query attempts to fail faster when // the server is unavailable, instead of retrying until the configured // timeout period. Note that an unavailable server may silently drop @@ -515,7 +586,7 @@ type DialInfo struct { // Defaults to 0. MinPoolSize int - //The maximum number of milliseconds that a connection can remain idle in the pool + // The maximum number of milliseconds that a connection can remain idle in the pool // before being removed and closed. MaxIdleTimeMS int @@ -527,6 +598,79 @@ type DialInfo struct { Dial func(addr net.Addr) (net.Conn, error) } +// Copy returns a deep-copy of i. +func (i *DialInfo) Copy() *DialInfo { + var readPreference *ReadPreference + if i.ReadPreference != nil { + readPreference = &ReadPreference{ + Mode: i.ReadPreference.Mode, + } + readPreference.TagSets = make([]bson.D, len(i.ReadPreference.TagSets)) + copy(readPreference.TagSets, i.ReadPreference.TagSets) + } + + info := &DialInfo{ + Timeout: i.Timeout, + Database: i.Database, + ReplicaSetName: i.ReplicaSetName, + Source: i.Source, + Service: i.Service, + ServiceHost: i.ServiceHost, + Mechanism: i.Mechanism, + Username: i.Username, + Password: i.Password, + PoolLimit: i.PoolLimit, + PoolTimeout: i.PoolTimeout, + ReadTimeout: i.ReadTimeout, + WriteTimeout: i.WriteTimeout, + AppName: i.AppName, + ReadPreference: readPreference, + FailFast: i.FailFast, + Direct: i.Direct, + MinPoolSize: i.MinPoolSize, + MaxIdleTimeMS: i.MaxIdleTimeMS, + DialServer: i.DialServer, + Dial: i.Dial, + } + + info.Addrs = make([]string, len(i.Addrs)) + copy(info.Addrs, i.Addrs) + + return info +} + +// readTimeout returns the configured read timeout, or i.Timeout if it's not set +func (i *DialInfo) readTimeout() time.Duration { + if i.ReadTimeout == zeroDuration { + return i.Timeout + } + return i.ReadTimeout +} + +// writeTimeout returns the configured write timeout, or i.Timeout if it's not +// set +func (i *DialInfo) writeTimeout() time.Duration { + if i.WriteTimeout == zeroDuration { + return i.Timeout + } + return i.WriteTimeout +} + +// roundTripTimeout returns the total time allocated for a single network read +// and write. +func (i *DialInfo) roundTripTimeout() time.Duration { + return i.readTimeout() + i.writeTimeout() +} + +// poolLimit returns the configured connection pool size, or +// DefaultConnectionPoolLimit. +func (i *DialInfo) poolLimit() int { + if i.PoolLimit == 0 { + return DefaultConnectionPoolLimit + } + return i.PoolLimit +} + // ReadPreference defines the manner in which servers are chosen. type ReadPreference struct { // Mode determines the consistency of results. See Session.SetMode. @@ -556,7 +700,12 @@ func (addr *ServerAddr) TCPAddr() *net.TCPAddr { } // DialWithInfo establishes a new session to the cluster identified by info. -func DialWithInfo(info *DialInfo) (*Session, error) { +func DialWithInfo(dialInfo *DialInfo) (*Session, error) { + info := dialInfo.Copy() + info.PoolLimit = info.poolLimit() + info.ReadTimeout = info.readTimeout() + info.WriteTimeout = info.writeTimeout() + addrs := make([]string, len(info.Addrs)) for i, addr := range info.Addrs { p := strings.LastIndexAny(addr, "]:") @@ -566,8 +715,8 @@ func DialWithInfo(info *DialInfo) (*Session, error) { } addrs[i] = addr } - cluster := newCluster(addrs, info.Direct, info.FailFast, dialer{info.Dial, info.DialServer}, info.ReplicaSetName, info.AppName) - session := newSession(Eventual, cluster, info.Timeout) + cluster := newCluster(addrs, info) + session := newSession(Eventual, cluster, info) session.defaultdb = info.Database if session.defaultdb == "" { session.defaultdb = "test" @@ -595,16 +744,6 @@ func DialWithInfo(info *DialInfo) (*Session, error) { } session.creds = []Credential{*session.dialCred} } - if info.PoolLimit > 0 { - session.poolLimit = info.PoolLimit - } - - cluster.minPoolSize = info.MinPoolSize - cluster.maxIdleTimeMS = info.MaxIdleTimeMS - - if info.PoolTimeout > 0 { - session.poolTimeout = info.PoolTimeout - } cluster.Release() @@ -617,6 +756,8 @@ func DialWithInfo(info *DialInfo) (*Session, error) { return nil, err } + session.SetSafe(&info.Safe) + if info.ReadPreference != nil { session.SelectServers(info.ReadPreference.TagSets...) session.SetMode(info.ReadPreference.Mode, true) @@ -624,6 +765,8 @@ func DialWithInfo(info *DialInfo) (*Session, error) { session.SetMode(Strong, true) } + session.dialInfo = info + return session, nil } @@ -684,13 +827,12 @@ func extractURL(s string) (*urlInfo, error) { return info, nil } -func newSession(consistency Mode, cluster *mongoCluster, timeout time.Duration) (session *Session) { +func newSession(consistency Mode, cluster *mongoCluster, info *DialInfo) (session *Session) { cluster.Acquire() session = &Session{ mgoCluster: cluster, - syncTimeout: timeout, - sockTimeout: timeout, - poolLimit: 4096, + syncTimeout: info.Timeout, + dialInfo: info, } debugf("New session %p on cluster %p", session, cluster) session.SetMode(consistency, true) @@ -719,9 +861,6 @@ func copySession(session *Session, keepCreds bool) (s *Session) { defaultdb: session.defaultdb, sourcedb: session.sourcedb, syncTimeout: session.syncTimeout, - sockTimeout: session.sockTimeout, - poolLimit: session.poolLimit, - poolTimeout: session.poolTimeout, consistency: session.consistency, creds: creds, dialCred: session.dialCred, @@ -733,6 +872,7 @@ func copySession(session *Session, keepCreds bool) (s *Session) { queryConfig: session.queryConfig, bypassValidation: session.bypassValidation, slaveOk: session.slaveOk, + dialInfo: session.dialInfo, } s = &scopy debugf("New session %p on cluster %p (copy from %p)", s, cluster, session) @@ -1332,7 +1472,6 @@ type Index struct { // Collation allows users to specify language-specific rules for string comparison, // such as rules for lettercase and accent marks. type Collation struct { - // Locale defines the collation locale. Locale string `bson:"locale"` @@ -2018,13 +2157,21 @@ func (s *Session) SetSyncTimeout(d time.Duration) { s.m.Unlock() } -// SetSocketTimeout sets the amount of time to wait for a non-responding -// socket to the database before it is forcefully closed. +// SetSocketTimeout is deprecated - use DialInfo read/write timeouts instead. +// +// SetSocketTimeout sets the amount of time to wait for a non-responding socket +// to the database before it is forcefully closed. // // The default timeout is 1 minute. func (s *Session) SetSocketTimeout(d time.Duration) { s.m.Lock() - s.sockTimeout = d + + // Set both the read and write timeout, as well as the DialInfo.Timeout for + // backwards compatibility, + s.dialInfo.Timeout = d + s.dialInfo.ReadTimeout = d + s.dialInfo.WriteTimeout = d + if s.masterSocket != nil { s.masterSocket.SetTimeout(d) } @@ -2058,7 +2205,7 @@ func (s *Session) SetCursorTimeout(d time.Duration) { // of used resources and number of goroutines before they are created. func (s *Session) SetPoolLimit(limit int) { s.m.Lock() - s.poolLimit = limit + s.dialInfo.PoolLimit = limit s.m.Unlock() } @@ -2068,7 +2215,7 @@ func (s *Session) SetPoolLimit(limit int) { // The default value is zero, which means to wait forever with no timeout. func (s *Session) SetPoolTimeout(timeout time.Duration) { s.m.Lock() - s.poolTimeout = timeout + s.dialInfo.PoolTimeout = timeout s.m.Unlock() } @@ -4137,9 +4284,11 @@ func (iter *Iter) Timeout() bool { // // Next returns true if a document was successfully unmarshalled onto result, // and false at the end of the result set or if an error happened. -// When Next returns false, the Err method should be called to verify if -// there was an error during iteration, and the Timeout method to verify if the -// false return value was caused by a timeout (no available results). +// When Next returns false, either the Err method or the Close method should be +// called to verify if there was an error during iteration. While both will +// return the error (or nil), Close will also release the cursor on the server. +// The Timeout method may also be called to verify if the false return value +// was caused by a timeout (no available results). // // For example: // @@ -4147,6 +4296,9 @@ func (iter *Iter) Timeout() bool { // for iter.Next(&result) { // fmt.Printf("Result: %v\n", result.Id) // } +// if iter.Timeout() { +// // react to timeout +// } // if err := iter.Close(); err != nil { // return err // } @@ -4275,10 +4427,19 @@ func (iter *Iter) Next(result interface{}) bool { // func (iter *Iter) All(result interface{}) error { resultv := reflect.ValueOf(result) - if resultv.Kind() != reflect.Ptr || resultv.Elem().Kind() != reflect.Slice { + if resultv.Kind() != reflect.Ptr { panic("result argument must be a slice address") } + slicev := resultv.Elem() + + if slicev.Kind() == reflect.Interface { + slicev = slicev.Elem() + } + if slicev.Kind() != reflect.Slice { + panic("result argument must be a slice address") + } + slicev = slicev.Slice(0, slicev.Cap()) elemt := slicev.Type().Elem() i := 0 @@ -4357,11 +4518,13 @@ func (iter *Iter) acquireSocket() (*mongoSocket, error) { // with Eventual sessions, if a Refresh is done, or if a // monotonic session gets a write and shifts from secondary // to primary. Our cursor is in a specific server, though. + iter.session.m.Lock() - sockTimeout := iter.session.sockTimeout + info := iter.session.dialInfo iter.session.m.Unlock() + socket.Release() - socket, _, err = iter.server.AcquireSocket(0, sockTimeout) + socket, _, err = iter.server.AcquireSocket(info) if err != nil { return nil, err } @@ -4434,10 +4597,11 @@ func (iter *Iter) getMoreCmd() *queryOp { type countCmd struct { Count string Query interface{} - Limit int32 `bson:",omitempty"` - Skip int32 `bson:",omitempty"` - Hint bson.D `bson:"hint,omitempty"` - MaxTimeMS int `bson:"maxTimeMS,omitempty"` + Limit int32 `bson:",omitempty"` + Skip int32 `bson:",omitempty"` + Hint bson.D `bson:"hint,omitempty"` + MaxTimeMS int `bson:"maxTimeMS,omitempty"` + Collation *Collation `bson:"collation,omitempty"` } // Count returns the total number of documents in the result set. @@ -4463,7 +4627,7 @@ func (q *Query) Count() (n int, err error) { // simply want a Zero bson.D hint, _ := q.op.options.Hint.(bson.D) result := struct{ N int }{} - err = session.DB(dbname).Run(countCmd{cname, query, limit, op.skip, hint, op.options.MaxTimeMS}, &result) + err = session.DB(dbname).Run(countCmd{cname, query, limit, op.skip, hint, op.options.MaxTimeMS, op.options.Collation}, &result) return result.N, err } @@ -4744,11 +4908,13 @@ type findModifyCmd struct { Collection string `bson:"findAndModify"` Query, Update, Sort, Fields interface{} `bson:",omitempty"` Upsert, Remove, New bool `bson:",omitempty"` + WriteConcern interface{} `bson:"writeConcern"` } type valueResult struct { - Value bson.Raw - LastError LastError `bson:"lastErrorObject"` + Value bson.Raw + LastError LastError `bson:"lastErrorObject"` + ConcernError writeConcernError `bson:"writeConcernError"` } // Apply runs the findAndModify MongoDB command, which allows updating, upserting @@ -4756,6 +4922,8 @@ type valueResult struct { // version (the default) or the new version of the document (when ReturnNew is true). // If no objects are found Apply returns ErrNotFound. // +// If the session is in safe mode, the LastError result will be returned as err. +// // The Sort and Select query methods affect the result of Apply. In case // multiple documents match the query, Sort enables selecting which document to // act upon by ordering it first. Select enables retrieving only a selection @@ -4792,15 +4960,27 @@ func (q *Query) Apply(change Change, result interface{}) (info *ChangeInfo, err dbname := op.collection[:c] cname := op.collection[c+1:] + // https://docs.mongodb.com/manual/reference/command/findAndModify/#dbcmd.findAndModify + session.m.RLock() + safeOp := session.safeOp + session.m.RUnlock() + var writeConcern interface{} + if safeOp == nil { + writeConcern = bson.D{{Name: "w", Value: 0}} + } else { + writeConcern = safeOp.query.(*getLastError) + } + cmd := findModifyCmd{ - Collection: cname, - Update: change.Update, - Upsert: change.Upsert, - Remove: change.Remove, - New: change.ReturnNew, - Query: op.query, - Sort: op.options.OrderBy, - Fields: op.selector, + Collection: cname, + Update: change.Update, + Upsert: change.Upsert, + Remove: change.Remove, + New: change.ReturnNew, + Query: op.query, + Sort: op.options.OrderBy, + Fields: op.selector, + WriteConcern: writeConcern, } session = session.Clone() @@ -4843,6 +5023,14 @@ func (q *Query) Apply(change Change, result interface{}) (info *ChangeInfo, err } else if change.Upsert { info.UpsertedId = lerr.UpsertedId } + if doc.ConcernError.Code != 0 { + var lerr LastError + e := doc.ConcernError + lerr.Code = e.Code + lerr.Err = e.ErrMsg + err = &lerr + return info, err + } return info, nil } @@ -4951,7 +5139,11 @@ func (s *Session) acquireSocket(slaveOk bool) (*mongoSocket, error) { // Still not good. We need a new socket. sock, err := s.cluster().AcquireSocketWithPoolTimeout( - s.consistency, slaveOk && s.slaveOk, s.syncTimeout, s.sockTimeout, s.queryConfig.op.serverTags, s.poolLimit, s.poolTimeout, + s.consistency, + slaveOk && s.slaveOk, + s.syncTimeout, + s.queryConfig.op.serverTags, + s.dialInfo, ) if err != nil { return nil, err diff --git a/session_internal_test.go b/session_internal_test.go index ddce59cae..3e214b174 100644 --- a/session_internal_test.go +++ b/session_internal_test.go @@ -3,9 +3,11 @@ package mgo import ( "crypto/x509/pkix" "encoding/asn1" + "testing" + "time" + "github.com/globalsign/mgo/bson" . "gopkg.in/check.v1" - "testing" ) type S struct{} @@ -62,3 +64,22 @@ func (s *S) TestGetRFC2253NameStringMultiValued(c *C) { c.Assert(getRFC2253NameString(&RDNElements), Equals, "OU=Sales+CN=J. Smith,O=Widget Inc.,C=US") } + +func (s *S) TestDialTimeouts(c *C) { + info := &DialInfo{} + + c.Assert(info.readTimeout(), Equals, time.Duration(0)) + c.Assert(info.writeTimeout(), Equals, time.Duration(0)) + c.Assert(info.roundTripTimeout(), Equals, time.Duration(0)) + + info.Timeout = 60 * time.Second + c.Assert(info.readTimeout(), Equals, 60*time.Second) + c.Assert(info.writeTimeout(), Equals, 60*time.Second) + c.Assert(info.roundTripTimeout(), Equals, 120*time.Second) + + info.ReadTimeout = time.Second + c.Assert(info.readTimeout(), Equals, time.Second) + + info.WriteTimeout = time.Second + c.Assert(info.writeTimeout(), Equals, time.Second) +} diff --git a/session_test.go b/session_test.go index 14cb9b1a6..0a897b61d 100644 --- a/session_test.go +++ b/session_test.go @@ -87,6 +87,15 @@ func (s *S) TestPing(c *C) { c.Assert(stats.ReceivedOps, Equals, 1) } +func (s *S) TestPingSsl(c *C) { + c.Skip("this test requires the usage of the system provided certificates") + session, err := mgo.Dial("localhost:40001?ssl=true") + c.Assert(err, IsNil) + defer session.Close() + + c.Assert(session.Ping(), IsNil) +} + func (s *S) TestDialIPAddress(c *C) { session, err := mgo.Dial("127.0.0.1:40001") c.Assert(err, IsNil) @@ -135,6 +144,25 @@ func (s *S) TestURLParsing(c *C) { } } +func (s *S) TestURLSsl(c *C) { + type test struct { + url string + nilDialServer bool + } + + tests := []test{ + {"localhost:40001", true}, + {"localhost:40001?ssl=false", true}, + {"localhost:40001?ssl=true", false}, + } + + for _, test := range tests { + info, err := mgo.ParseURL(test.url) + c.Assert(err, IsNil) + c.Assert(info.DialServer == nil, Equals, test.nilDialServer) + } +} + func (s *S) TestURLReadPreference(c *C) { type test struct { url string @@ -168,6 +196,43 @@ func (s *S) TestURLInvalidReadPreference(c *C) { } } +func (s *S) TestURLSafe(c *C) { + type test struct { + url string + safe mgo.Safe + } + + tests := []test{ + {"localhost:40001?w=majority", mgo.Safe{WMode: "majority"}}, + {"localhost:40001?j=true", mgo.Safe{J: true}}, + {"localhost:40001?j=false", mgo.Safe{J: false}}, + {"localhost:40001?wtimeoutMS=1", mgo.Safe{WTimeout: 1}}, + {"localhost:40001?wtimeoutMS=1000", mgo.Safe{WTimeout: 1000}}, + {"localhost:40001?w=1&j=true&wtimeoutMS=1000", mgo.Safe{WMode: "1", J: true, WTimeout: 1000}}, + } + + for _, test := range tests { + info, err := mgo.ParseURL(test.url) + c.Assert(err, IsNil) + c.Assert(info.Safe, NotNil) + c.Assert(info.Safe, Equals, test.safe) + } +} + +func (s *S) TestURLInvalidSafe(c *C) { + urls := []string{ + "localhost:40001?wtimeoutMS=abc", + "localhost:40001?wtimeoutMS=", + "localhost:40001?wtimeoutMS=-1", + "localhost:40001?j=12", + "localhost:40001?j=foo", + } + for _, url := range urls { + _, err := mgo.ParseURL(url) + c.Assert(err, NotNil) + } +} + func (s *S) TestMinPoolSize(c *C) { tests := []struct { url string @@ -416,6 +481,18 @@ func (s *S) TestInsertFindAll(c *C) { // Ensure result is backed by the originally allocated array c.Assert(&result[0], Equals, &allocd[0]) + // Re-run test destination as a pointer to interface{} + var resultInterface interface{} + + anotherslice := make([]R, 5) + resultInterface = anotherslice + err = coll.Find(nil).Sort("a").All(&resultInterface) + c.Assert(err, IsNil) + assertResult() + + // Ensure result is backed by the originally allocated array + c.Assert(&result[0], Equals, &allocd[0]) + // Non-pointer slice error f := func() { coll.Find(nil).All(result) } c.Assert(f, Panics, "result argument must be a slice address") @@ -1321,6 +1398,37 @@ func (s *S) TestFindAndModify(c *C) { c.Assert(info, IsNil) } +func (s *S) TestFindAndModifyWriteConcern(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"id": 42}) + c.Assert(err, IsNil) + + // Tweak the safety parameters to something unachievable. + session.SetSafe(&mgo.Safe{W: 4, WTimeout: 100}) + + var ret struct { + Id uint64 `bson:"id"` + } + + change := mgo.Change{ + Update: M{"$inc": M{"id": 8}}, + ReturnNew: false, + } + info, err := coll.Find(M{"id": M{"$exists": true}}).Apply(change, &ret) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.Matched, Equals, 1) + c.Assert(ret.Id, Equals, uint64(42)) + + if s.versionAtLeast(3, 2) { + // findAndModify support writeConcern after version 3.2. + c.Assert(err, ErrorMatches, "timeout|timed out waiting for slaves|Not enough data-bearing nodes|waiting for replication timed out") + } +} + func (s *S) TestFindAndModifyBug997828(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) @@ -1523,6 +1631,38 @@ func (s *S) TestCountQuery(c *C) { c.Assert(n, Equals, 2) } +func (s *S) TestCountQueryWithCollation(c *C) { + if !s.versionAtLeast(3, 4) { + c.Skip("depends on mongodb 3.4+") + } + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + c.Assert(err, IsNil) + + collation := &mgo.Collation{ + Locale: "en", + Strength: 2, + } + err = coll.EnsureIndex(mgo.Index{ + Key: []string{"n"}, + Collation: collation, + }) + c.Assert(err, IsNil) + + ns := []string{"hello", "Hello", "hEllO"} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Find(M{"n": "hello"}).Collation(collation).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) +} + func (s *S) TestCountQuerySorted(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) diff --git a/socket.go b/socket.go index ae13e401f..9dcedf219 100644 --- a/socket.go +++ b/socket.go @@ -42,7 +42,6 @@ type mongoSocket struct { sync.Mutex server *mongoServer // nil when cached conn net.Conn - timeout time.Duration addr string // For debugging only. nextRequestId uint32 replyFuncs map[uint32]replyFunc @@ -56,6 +55,8 @@ type mongoSocket struct { closeAfterIdle bool lastTimeUsed time.Time // for time based idle socket release sendMeta sync.Once + + dialInfo *DialInfo } type queryOpFlags uint32 @@ -181,15 +182,16 @@ type requestInfo struct { replyFunc replyFunc } -func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket { +func newSocket(server *mongoServer, conn net.Conn, info *DialInfo) *mongoSocket { socket := &mongoSocket{ conn: conn, addr: server.Addr, server: server, replyFuncs: make(map[uint32]replyFunc), + dialInfo: info, } socket.gotNonce.L = &socket.Mutex - if err := socket.InitialAcquire(server.Info(), timeout); err != nil { + if err := socket.InitialAcquire(server.Info(), info); err != nil { panic("newSocket: InitialAcquire returned error: " + err.Error()) } stats.socketsAlive(+1) @@ -223,7 +225,7 @@ func (socket *mongoSocket) ServerInfo() *mongoServerInfo { // InitialAcquire obtains the first reference to the socket, either // right after the connection is made or once a recycled socket is // being put back in use. -func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error { +func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, dialInfo *DialInfo) error { socket.Lock() if socket.references > 0 { panic("Socket acquired out of cache with references") @@ -235,7 +237,7 @@ func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout t } socket.references++ socket.serverInfo = serverInfo - socket.timeout = timeout + socket.dialInfo = dialInfo stats.socketsInUse(+1) stats.socketRefs(+1) socket.Unlock() @@ -288,7 +290,8 @@ func (socket *mongoSocket) Release() { // SetTimeout changes the timeout used on socket operations. func (socket *mongoSocket) SetTimeout(d time.Duration) { socket.Lock() - socket.timeout = d + socket.dialInfo.ReadTimeout = d + socket.dialInfo.WriteTimeout = d socket.Unlock() } @@ -301,24 +304,37 @@ const ( func (socket *mongoSocket) updateDeadline(which deadlineType) { var when time.Time - if socket.timeout > 0 { - when = time.Now().Add(socket.timeout) - } - whichstr := "" + var whichStr string switch which { case readDeadline | writeDeadline: - whichstr = "read/write" + if socket.dialInfo.roundTripTimeout() == 0 { + return + } + whichStr = "read/write" + when = time.Now().Add(socket.dialInfo.roundTripTimeout()) socket.conn.SetDeadline(when) + case readDeadline: - whichstr = "read" + if socket.dialInfo.ReadTimeout == zeroDuration { + return + } + whichStr = "read" + when = time.Now().Add(socket.dialInfo.ReadTimeout) socket.conn.SetReadDeadline(when) + case writeDeadline: - whichstr = "write" + if socket.dialInfo.WriteTimeout == zeroDuration { + return + } + whichStr = "write" + when = time.Now().Add(socket.dialInfo.WriteTimeout) socket.conn.SetWriteDeadline(when) + default: panic("invalid parameter to updateDeadline") } - debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when) + + debugf("Socket %p to %s: updated %s deadline to %s", socket, socket.addr, whichStr, when) } // Close terminates the socket use.