diff --git a/cluster.go b/cluster.go index 4e54c5d81..4959a46ba 100644 --- a/cluster.go +++ b/cluster.go @@ -146,8 +146,17 @@ type isMasterResult struct { } func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResult) error { + // I'm not really sure why the timeout was hard coded to these values (I + // assume because everything is passed as a func arg, and thus this info is + // unavailable here), but leaving them as is for backwards compatibility. + config := &DialInfo{ + Timeout: 10 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + // Monotonic let's it talk to a slave and still hold the socket. - session := newSession(Monotonic, cluster, 10*time.Second) + session := newSession(Monotonic, cluster, config) session.setSocket(socket) var cmd = bson.D{{Name: "isMaster", Value: 1}} @@ -624,9 +633,7 @@ func (cluster *mongoCluster) AcquireSocket(mode Mode, slaveOk bool, syncTimeout // AcquireSocketWithPoolTimeout returns a socket to a server in the cluster. If slaveOk is // true, it will attempt to return a socket to a slave server. If it is // false, the socket will necessarily be to a master server. -func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( - mode Mode, slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int, poolTimeout time.Duration, -) (s *mongoSocket, err error) { +func (cluster *mongoCluster) AcquireSocketWithPoolTimeout(mode Mode, slaveOk bool, syncTimeout time.Duration, serverTags []bson.D, info *DialInfo) (s *mongoSocket, err error) { var started time.Time var syncCount uint for { @@ -670,7 +677,7 @@ func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( continue } - s, abended, err := server.AcquireSocketWithBlocking(poolLimit, socketTimeout, poolTimeout) + s, abended, err := server.AcquireSocketWithBlocking(info) if err == errPoolTimeout { // No need to remove servers from the topology if acquiring a socket fails for this reason. return nil, err diff --git a/server.go b/server.go index f34624f74..d8663efd8 100644 --- a/server.go +++ b/server.go @@ -124,21 +124,23 @@ var errServerClosed = errors.New("server was closed") // use in this server is greater than the provided limit, errPoolLimit is // returned. func (server *mongoServer) AcquireSocket(poolLimit int, timeout time.Duration) (socket *mongoSocket, abended bool, err error) { - return server.acquireSocketInternal(poolLimit, timeout, false, 0*time.Millisecond) + info := &DialInfo{ + PoolLimit: poolLimit, + ReadTimeout: timeout, + WriteTimeout: timeout, + Timeout: timeout, + } + return server.acquireSocketInternal(info, false) } // AcquireSocketWithBlocking wraps AcquireSocket, but if a socket is not available, it will _not_ // return errPoolLimit. Instead, it will block waiting for a socket to become available. If poolTimeout // should elapse before a socket is available, it will return errPoolTimeout. -func (server *mongoServer) AcquireSocketWithBlocking( - poolLimit int, socketTimeout time.Duration, poolTimeout time.Duration, -) (socket *mongoSocket, abended bool, err error) { - return server.acquireSocketInternal(poolLimit, socketTimeout, true, poolTimeout) +func (server *mongoServer) AcquireSocketWithBlocking(info *DialInfo) (socket *mongoSocket, abended bool, err error) { + return server.acquireSocketInternal(info, true) } -func (server *mongoServer) acquireSocketInternal( - poolLimit int, timeout time.Duration, shouldBlock bool, poolTimeout time.Duration, -) (socket *mongoSocket, abended bool, err error) { +func (server *mongoServer) acquireSocketInternal(info *DialInfo, shouldBlock bool) (socket *mongoSocket, abended bool, err error) { for { server.Lock() abended = server.abended @@ -146,7 +148,7 @@ func (server *mongoServer) acquireSocketInternal( server.Unlock() return nil, abended, errServerClosed } - if poolLimit > 0 { + if info.poolLimit() > 0 { if shouldBlock { // Beautiful. Golang conditions don't have a WaitWithTimeout, so I've implemented the timeout // with a wait + broadcast. The broadcast will cause the loop here to re-check the timeout, @@ -158,11 +160,11 @@ func (server *mongoServer) acquireSocketInternal( // https://github.com/golang/go/issues/16620, since the lock needs to be held in _this_ goroutine. waitDone := make(chan struct{}) timeoutHit := false - if poolTimeout > 0 { + if info.PoolTimeout > 0 { go func() { select { case <-waitDone: - case <-time.After(poolTimeout): + case <-time.After(info.PoolTimeout): // timeoutHit is part of the wait condition, so needs to be changed under mutex. server.Lock() defer server.Unlock() @@ -172,7 +174,7 @@ func (server *mongoServer) acquireSocketInternal( }() } timeSpentWaiting := time.Duration(0) - for len(server.liveSockets)-len(server.unusedSockets) >= poolLimit && !timeoutHit { + for len(server.liveSockets)-len(server.unusedSockets) >= info.poolLimit() && !timeoutHit { // We only count time spent in Wait(), and not time evaluating the entire loop, // so that in the happy non-blocking path where the condition above evaluates true // first time, we record a nice round zero wait time. @@ -191,7 +193,7 @@ func (server *mongoServer) acquireSocketInternal( // Record that we fetched a connection of of a socket list and how long we spent waiting stats.noticeSocketAcquisition(timeSpentWaiting) } else { - if len(server.liveSockets)-len(server.unusedSockets) >= poolLimit { + if len(server.liveSockets)-len(server.unusedSockets) >= info.poolLimit() { server.Unlock() return nil, false, errPoolLimit } @@ -202,15 +204,15 @@ func (server *mongoServer) acquireSocketInternal( socket = server.unusedSockets[n-1] server.unusedSockets[n-1] = nil // Help GC. server.unusedSockets = server.unusedSockets[:n-1] - info := server.info + serverInfo := server.info server.Unlock() - err = socket.InitialAcquire(info, timeout) + err = socket.InitialAcquire(serverInfo, info) if err != nil { continue } } else { server.Unlock() - socket, err = server.Connect(timeout) + socket, err = server.Connect(info) if err == nil { server.Lock() // We've waited for the Connect, see if we got @@ -231,20 +233,18 @@ func (server *mongoServer) acquireSocketInternal( // Connect establishes a new connection to the server. This should // generally be done through server.AcquireSocket(). -func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) { +func (server *mongoServer) Connect(info *DialInfo) (*mongoSocket, error) { server.RLock() master := server.info.Master dial := server.dial server.RUnlock() - logf("Establishing new connection to %s (timeout=%s)...", server.Addr, timeout) + logf("Establishing new connection to %s (timeout=%s)...", server.Addr, info.Timeout) var conn net.Conn var err error switch { case !dial.isSet(): - // Cannot do this because it lacks timeout support. :-( - //conn, err = net.DialTCP("tcp", nil, server.tcpaddr) - conn, err = net.DialTimeout("tcp", server.ResolvedAddr, timeout) + conn, err = net.DialTimeout("tcp", server.ResolvedAddr, info.Timeout) if tcpconn, ok := conn.(*net.TCPConn); ok { tcpconn.SetKeepAlive(true) } else if err == nil { @@ -264,7 +264,7 @@ func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) logf("Connection to %s established.", server.Addr) stats.conn(+1, master) - return newSocket(server, conn, timeout), nil + return newSocket(server, conn, info), nil } // Close forces closing all sockets that are alive, whether diff --git a/session.go b/session.go index 167a0375d..162d015c0 100644 --- a/session.go +++ b/session.go @@ -73,6 +73,24 @@ const ( Monotonic Mode = 1 // Strong mode is specific to mgo, and is same as Primary. Strong Mode = 2 + + // DefaultConnectionPoolSize defines the default maximum number of + // connections in the connection pool. + // + // To override this value set DialInfo.PoolLimit. + DefaultConnectionPoolSize = 4096 + + // DefaultReadTimeout is set to 60 seconds for backwards compatibility. + // + // See DialInfo.ReadTimeout + DefaultReadTimeout = time.Second * 60 + + // DefaultWriteTimeout is set to 60 seconds for backwards compatibility. + // + // See DialInfo.WriteTimeout + DefaultWriteTimeout = time.Second * 60 + + zeroDuration = time.Duration(0) ) // mgo.v3: Drop Strong mode, suffix all modes with "Mode". @@ -90,9 +108,6 @@ type Session struct { defaultdb string sourcedb string syncTimeout time.Duration - sockTimeout time.Duration - poolLimit int - poolTimeout time.Duration consistency Mode creds []Credential dialCred *Credential @@ -104,6 +119,8 @@ type Session struct { queryConfig query bypassValidation bool slaveOk bool + + dialInfo *DialInfo } // Database holds collections of documents @@ -196,7 +213,7 @@ const ( // Dial will timeout after 10 seconds if a server isn't reached. The returned // session will timeout operations after one minute by default if servers aren't // available. To customize the timeout, see DialWithTimeout, SetSyncTimeout, and -// SetSocketTimeout. +// DialInfo Read/WriteTimeout. // // This method is generally called just once for a given cluster. Further // sessions to the same cluster are then established using the New or Copy @@ -483,15 +500,38 @@ type DialInfo struct { Username string Password string - // PoolLimit defines the per-server socket pool limit. Defaults to 4096. - // See Session.SetPoolLimit for details. + // PoolLimit defines the per-server socket pool limit. Defaults to + // DefaultConnectionPoolSize. See Session.SetPoolLimit for details. PoolLimit int // PoolTimeout defines max time to wait for a connection to become available - // if the pool limit is reaqched. Defaults to zero, which means forever. - // See Session.SetPoolTimeout for details + // if the pool limit is reached. Defaults to zero, which means forever. See + // Session.SetPoolTimeout for details PoolTimeout time.Duration + // ReadTimeout defines the maximum duration to wait for a response from + // MongoDB. + // + // This effectively limits the maximum query execution time. If a MongoDB + // query duration exceeds this timeout, the caller will receive a timeout, + // however MongoDB will continue processing the query. This duration must be + // large enough to allow MongoDB to execute the query, and the response be + // received over the network connection. + // + // Only limits the network read - does not include unmarshalling / + // processing of the response. Defaults to DialInfo.Timeout. If 0, no + // timeout is set. + ReadTimeout time.Duration + + // WriteTimeout defines the maximum duration of a write to MongoDB over the + // network connection. + // + // This is can usually be low unless writing large documents, or over a high + // latency link. Only limits network write time - does not include + // marshalling/processing the request. Defaults to DialInfo.Timeout. If 0, + // no timeout is set. + WriteTimeout time.Duration + // The identifier of the client application which ran the operation. AppName string @@ -527,6 +567,80 @@ type DialInfo struct { Dial func(addr net.Addr) (net.Conn, error) } +// Copy returns a deep-copy of i. +func (i *DialInfo) Copy() *DialInfo { + var readPreference *ReadPreference + if i.ReadPreference != nil { + readPreference = &ReadPreference{ + Mode: i.ReadPreference.Mode, + } + readPreference.TagSets = make([]bson.D, len(i.ReadPreference.TagSets)) + copy(readPreference.TagSets, i.ReadPreference.TagSets) + } + + info := &DialInfo{ + Timeout: i.Timeout, + Database: i.Database, + ReplicaSetName: i.ReplicaSetName, + Source: i.Source, + Service: i.Service, + ServiceHost: i.ServiceHost, + Mechanism: i.Mechanism, + Username: i.Username, + Password: i.Password, + PoolLimit: i.PoolLimit, + PoolTimeout: i.PoolTimeout, + ReadTimeout: i.ReadTimeout, + WriteTimeout: i.WriteTimeout, + AppName: i.AppName, + ReadPreference: readPreference, + FailFast: i.FailFast, + Direct: i.Direct, + MinPoolSize: i.MinPoolSize, + MaxIdleTimeMS: i.MaxIdleTimeMS, + DialServer: i.DialServer, + Dial: i.Dial, + } + + info.Addrs = make([]string, len(i.Addrs)) + copy(info.Addrs, i.Addrs) + + return info +} + +// readTimeout returns the configured read timeout, or DefaultReadTimeout if +// unset. +func (i *DialInfo) readTimeout() time.Duration { + if i.ReadTimeout == zeroDuration { + return i.Timeout + } + return i.ReadTimeout +} + +// writeTimeout returns the configured write timeout, or DefaultWriteTimeout if +// unset. +func (i *DialInfo) writeTimeout() time.Duration { + if i.WriteTimeout == zeroDuration { + return i.Timeout + } + return i.WriteTimeout +} + +// roundTripTimeout returns the total time allocated for a single network read +// and write. +func (i *DialInfo) roundTripTimeout() time.Duration { + return i.readTimeout() + i.writeTimeout() +} + +// poolLimit returns the configured connection pool size, or +// DefaultConnectionPoolSize. +func (i *DialInfo) poolLimit() int { + if i == nil || i.PoolLimit == 0 { + return DefaultConnectionPoolSize + } + return i.PoolLimit +} + // ReadPreference defines the manner in which servers are chosen. type ReadPreference struct { // Mode determines the consistency of results. See Session.SetMode. @@ -556,7 +670,8 @@ func (addr *ServerAddr) TCPAddr() *net.TCPAddr { } // DialWithInfo establishes a new session to the cluster identified by info. -func DialWithInfo(info *DialInfo) (*Session, error) { +func DialWithInfo(dialInfo *DialInfo) (*Session, error) { + info := dialInfo.Copy() addrs := make([]string, len(info.Addrs)) for i, addr := range info.Addrs { p := strings.LastIndexAny(addr, "]:") @@ -567,7 +682,7 @@ func DialWithInfo(info *DialInfo) (*Session, error) { addrs[i] = addr } cluster := newCluster(addrs, info.Direct, info.FailFast, dialer{info.Dial, info.DialServer}, info.ReplicaSetName, info.AppName) - session := newSession(Eventual, cluster, info.Timeout) + session := newSession(Eventual, cluster, info) session.defaultdb = info.Database if session.defaultdb == "" { session.defaultdb = "test" @@ -595,17 +710,9 @@ func DialWithInfo(info *DialInfo) (*Session, error) { } session.creds = []Credential{*session.dialCred} } - if info.PoolLimit > 0 { - session.poolLimit = info.PoolLimit - } cluster.minPoolSize = info.MinPoolSize cluster.maxIdleTimeMS = info.MaxIdleTimeMS - - if info.PoolTimeout > 0 { - session.poolTimeout = info.PoolTimeout - } - cluster.Release() // People get confused when we return a session that is not actually @@ -624,6 +731,8 @@ func DialWithInfo(info *DialInfo) (*Session, error) { session.SetMode(Strong, true) } + session.dialInfo = info + return session, nil } @@ -684,13 +793,12 @@ func extractURL(s string) (*urlInfo, error) { return info, nil } -func newSession(consistency Mode, cluster *mongoCluster, timeout time.Duration) (session *Session) { +func newSession(consistency Mode, cluster *mongoCluster, info *DialInfo) (session *Session) { cluster.Acquire() session = &Session{ mgoCluster: cluster, - syncTimeout: timeout, - sockTimeout: timeout, - poolLimit: 4096, + syncTimeout: info.Timeout, + dialInfo: info, } debugf("New session %p on cluster %p", session, cluster) session.SetMode(consistency, true) @@ -719,9 +827,6 @@ func copySession(session *Session, keepCreds bool) (s *Session) { defaultdb: session.defaultdb, sourcedb: session.sourcedb, syncTimeout: session.syncTimeout, - sockTimeout: session.sockTimeout, - poolLimit: session.poolLimit, - poolTimeout: session.poolTimeout, consistency: session.consistency, creds: creds, dialCred: session.dialCred, @@ -733,6 +838,7 @@ func copySession(session *Session, keepCreds bool) (s *Session) { queryConfig: session.queryConfig, bypassValidation: session.bypassValidation, slaveOk: session.slaveOk, + dialInfo: session.dialInfo, } s = &scopy debugf("New session %p on cluster %p (copy from %p)", s, cluster, session) @@ -2018,13 +2124,21 @@ func (s *Session) SetSyncTimeout(d time.Duration) { s.m.Unlock() } -// SetSocketTimeout sets the amount of time to wait for a non-responding -// socket to the database before it is forcefully closed. +// SetSocketTimeout is deprecated - use DialInfo read/write timeouts instead. +// +// SetSocketTimeout sets the amount of time to wait for a non-responding socket +// to the database before it is forcefully closed. // // The default timeout is 1 minute. func (s *Session) SetSocketTimeout(d time.Duration) { s.m.Lock() - s.sockTimeout = d + + // Set both the read and write timeout, as well as the DialInfo.Timeout for + // backwards compatibility, + s.dialInfo.Timeout = d + s.dialInfo.ReadTimeout = d + s.dialInfo.WriteTimeout = d + if s.masterSocket != nil { s.masterSocket.SetTimeout(d) } @@ -2058,7 +2172,7 @@ func (s *Session) SetCursorTimeout(d time.Duration) { // of used resources and number of goroutines before they are created. func (s *Session) SetPoolLimit(limit int) { s.m.Lock() - s.poolLimit = limit + s.dialInfo.PoolLimit = limit s.m.Unlock() } @@ -2068,7 +2182,7 @@ func (s *Session) SetPoolLimit(limit int) { // The default value is zero, which means to wait forever with no timeout. func (s *Session) SetPoolTimeout(timeout time.Duration) { s.m.Lock() - s.poolTimeout = timeout + s.dialInfo.PoolTimeout = timeout s.m.Unlock() } @@ -4356,9 +4470,11 @@ func (iter *Iter) acquireSocket() (*mongoSocket, error) { // with Eventual sessions, if a Refresh is done, or if a // monotonic session gets a write and shifts from secondary // to primary. Our cursor is in a specific server, though. + iter.session.m.Lock() - sockTimeout := iter.session.sockTimeout + sockTimeout := iter.session.dialInfo.roundTripTimeout() iter.session.m.Unlock() + socket.Release() socket, _, err = iter.server.AcquireSocket(0, sockTimeout) if err != nil { @@ -4950,7 +5066,11 @@ func (s *Session) acquireSocket(slaveOk bool) (*mongoSocket, error) { // Still not good. We need a new socket. sock, err := s.cluster().AcquireSocketWithPoolTimeout( - s.consistency, slaveOk && s.slaveOk, s.syncTimeout, s.sockTimeout, s.queryConfig.op.serverTags, s.poolLimit, s.poolTimeout, + s.consistency, + slaveOk && s.slaveOk, + s.syncTimeout, + s.queryConfig.op.serverTags, + s.dialInfo, ) if err != nil { return nil, err diff --git a/session_test.go b/session_test.go index 14cb9b1a6..c3723b909 100644 --- a/session_test.go +++ b/session_test.go @@ -4969,6 +4969,19 @@ func (s *S) TestCollationQueries(c *C) { } } +func (s *S) TestDialTimeouts(c *C) { + info := &mgo.DialInfo{} + + c.Assert(info.readTimeout(), Equals, mgo.DefaultReadTimeout) + c.Assert(info.writeTimeout(), Equals, mgo.DefaultWriteTimeout) + + info.ReadTimeout = time.Second + c.Assert(info.readTimeout(), Equals, time.Second) + + info.WriteTimewut = time.Second + c.Assert(info.writeTimeout(), Equals, time.Second) +} + // -------------------------------------------------------------------------- // Some benchmarks that require a running database. diff --git a/socket.go b/socket.go index ae13e401f..f3f477604 100644 --- a/socket.go +++ b/socket.go @@ -42,7 +42,6 @@ type mongoSocket struct { sync.Mutex server *mongoServer // nil when cached conn net.Conn - timeout time.Duration addr string // For debugging only. nextRequestId uint32 replyFuncs map[uint32]replyFunc @@ -56,6 +55,8 @@ type mongoSocket struct { closeAfterIdle bool lastTimeUsed time.Time // for time based idle socket release sendMeta sync.Once + + dialInfo *DialInfo } type queryOpFlags uint32 @@ -181,15 +182,16 @@ type requestInfo struct { replyFunc replyFunc } -func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket { +func newSocket(server *mongoServer, conn net.Conn, info *DialInfo) *mongoSocket { socket := &mongoSocket{ conn: conn, addr: server.Addr, server: server, replyFuncs: make(map[uint32]replyFunc), + dialInfo: info, } socket.gotNonce.L = &socket.Mutex - if err := socket.InitialAcquire(server.Info(), timeout); err != nil { + if err := socket.InitialAcquire(server.Info(), info); err != nil { panic("newSocket: InitialAcquire returned error: " + err.Error()) } stats.socketsAlive(+1) @@ -223,7 +225,7 @@ func (socket *mongoSocket) ServerInfo() *mongoServerInfo { // InitialAcquire obtains the first reference to the socket, either // right after the connection is made or once a recycled socket is // being put back in use. -func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error { +func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, dialInfo *DialInfo) error { socket.Lock() if socket.references > 0 { panic("Socket acquired out of cache with references") @@ -235,7 +237,7 @@ func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout t } socket.references++ socket.serverInfo = serverInfo - socket.timeout = timeout + socket.dialInfo = dialInfo stats.socketsInUse(+1) stats.socketRefs(+1) socket.Unlock() @@ -288,7 +290,8 @@ func (socket *mongoSocket) Release() { // SetTimeout changes the timeout used on socket operations. func (socket *mongoSocket) SetTimeout(d time.Duration) { socket.Lock() - socket.timeout = d + socket.dialInfo.ReadTimeout = d + socket.dialInfo.WriteTimeout = d socket.Unlock() } @@ -301,24 +304,37 @@ const ( func (socket *mongoSocket) updateDeadline(which deadlineType) { var when time.Time - if socket.timeout > 0 { - when = time.Now().Add(socket.timeout) - } - whichstr := "" + var whichstr string switch which { case readDeadline | writeDeadline: + if socket.dialInfo.roundTripTimeout() == 0 { + return + } whichstr = "read/write" + when = time.Now().Add(socket.dialInfo.roundTripTimeout()) socket.conn.SetDeadline(when) + case readDeadline: + if socket.dialInfo.readTimeout() == 0 { + return + } whichstr = "read" + when = time.Now().Add(socket.dialInfo.readTimeout()) socket.conn.SetReadDeadline(when) + case writeDeadline: + if socket.dialInfo.writeTimeout() == 0 { + return + } whichstr = "write" + when = time.Now().Add(socket.dialInfo.writeTimeout()) socket.conn.SetWriteDeadline(when) + default: panic("invalid parameter to updateDeadline") } - debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when) + + debugf("Socket %p to %s: updated %s deadline to %s", socket, socket.addr, whichstr, when) } // Close terminates the socket use.