diff --git a/balancer/rls/balancer_test.go b/balancer/rls/balancer_test.go index 3f019dcf851e..444b8b99d4a3 100644 --- a/balancer/rls/balancer_test.go +++ b/balancer/rls/balancer_test.go @@ -1030,11 +1030,7 @@ func (s) TestUpdateStatePauses(t *testing.T) { // the test would fail. Waiting for the channel to become READY here // ensures that the test does not flake because of this rare sequence of // events. - for s := cc.GetState(); s != connectivity.Ready; s = cc.GetState() { - if !cc.WaitForStateChange(ctx, s) { - t.Fatal("Timeout when waiting for connectivity state to reach READY") - } - } + testutils.AwaitState(ctx, t, cc, connectivity.Ready) // Cache the state changes seen up to this point. states0 := ccWrapper.getStates() diff --git a/call.go b/call.go index 7311d0126148..a67a3db02eb4 100644 --- a/call.go +++ b/call.go @@ -27,10 +27,10 @@ import ( // // All errors returned by Invoke are compatible with the status package. func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply any, opts ...CallOption) error { - if err := cc.idlenessMgr.onCallBegin(); err != nil { + if err := cc.idlenessMgr.OnCallBegin(); err != nil { return err } - defer cc.idlenessMgr.onCallEnd() + defer cc.idlenessMgr.OnCallEnd() // allow interceptor to see all applicable call options, which means those // configured as defaults from dial option as well as per-call options diff --git a/channelz/service/service_sktopt_test.go b/channelz/service/service_sktopt_test.go index 8ec5341cb030..1da38aa7fbf3 100644 --- a/channelz/service/service_sktopt_test.go +++ b/channelz/service/service_sktopt_test.go @@ -128,8 +128,6 @@ func protoToSocketOption(skopts []*channelzpb.SocketOption) *channelz.SocketOpti } func (s) TestGetSocketOptions(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer cleanupWrapper(czCleanup, t) ss := []*dummySocket{ { socketOptions: &channelz.SocketOptionData{ diff --git a/channelz/service/service_test.go b/channelz/service/service_test.go index 94ca6b8b35b7..38b1f7dda7d8 100644 --- a/channelz/service/service_test.go +++ b/channelz/service/service_test.go @@ -51,12 +51,6 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } -func cleanupWrapper(cleanup func() error, t *testing.T) { - if err := cleanup(); err != nil { - t.Error(err) - } -} - type protoToSocketOptFunc func([]*channelzpb.SocketOption) *channelz.SocketOptionData // protoToSocketOpt is used in function socketProtoToStruct to extract socket option @@ -311,8 +305,7 @@ func (s) TestGetTopChannels(t *testing.T) { }, {}, } - czCleanup := channelz.NewChannelzStorageForTesting() - defer cleanupWrapper(czCleanup, t) + for _, c := range tcs { id := channelz.RegisterChannel(c, nil, "") defer channelz.RemoveEntry(id) @@ -364,8 +357,7 @@ func (s) TestGetServers(t *testing.T) { lastCallStartedTimestamp: time.Now().UTC(), }, } - czCleanup := channelz.NewChannelzStorageForTesting() - defer cleanupWrapper(czCleanup, t) + for _, s := range ss { id := channelz.RegisterServer(s, "") defer channelz.RemoveEntry(id) @@ -397,8 +389,6 @@ func (s) TestGetServers(t *testing.T) { } func (s) TestGetServerSockets(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer cleanupWrapper(czCleanup, t) svrID := channelz.RegisterServer(&dummyServer{}, "") defer channelz.RemoveEntry(svrID) refNames := []string{"listen socket 1", "normal socket 1", "normal socket 2"} @@ -438,8 +428,6 @@ func (s) TestGetServerSockets(t *testing.T) { // This test makes a GetServerSockets with a non-zero start ID, and expect 
only // sockets with ID >= the given start ID. func (s) TestGetServerSocketsNonZeroStartID(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer cleanupWrapper(czCleanup, t) svrID := channelz.RegisterServer(&dummyServer{}, "") defer channelz.RemoveEntry(svrID) refNames := []string{"listen socket 1", "normal socket 1", "normal socket 2"} @@ -470,9 +458,6 @@ func (s) TestGetServerSocketsNonZeroStartID(t *testing.T) { } func (s) TestGetChannel(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer cleanupWrapper(czCleanup, t) - refNames := []string{"top channel 1", "nested channel 1", "sub channel 2", "nested channel 3"} ids := make([]*channelz.Identifier, 4) ids[0] = channelz.RegisterChannel(&dummyChannel{}, nil, refNames[0]) @@ -584,8 +569,7 @@ func (s) TestGetSubChannel(t *testing.T) { subchanConnectivityChange = fmt.Sprintf("Subchannel Connectivity change to %v", connectivity.Ready) subChanPickNewAddress = fmt.Sprintf("Subchannel picks a new address %q to connect", "0.0.0.0") ) - czCleanup := channelz.NewChannelzStorageForTesting() - defer cleanupWrapper(czCleanup, t) + refNames := []string{"top channel 1", "sub channel 1", "socket 1", "socket 2"} ids := make([]*channelz.Identifier, 4) ids[0] = channelz.RegisterChannel(&dummyChannel{}, nil, refNames[0]) @@ -662,8 +646,6 @@ func (s) TestGetSubChannel(t *testing.T) { } func (s) TestGetSocket(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer cleanupWrapper(czCleanup, t) ss := []*dummySocket{ { streamsStarted: 10, diff --git a/clientconn.go b/clientconn.go index b0d28c67c7d4..d53d91d5d9f3 100644 --- a/clientconn.go +++ b/clientconn.go @@ -38,6 +38,7 @@ import ( "google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/idle" "google.golang.org/grpc/internal/pretty" iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/transport" @@ -266,7 +267,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * // Configure idleness support with configured idle timeout or default idle // timeout duration. Idleness can be explicitly disabled by the user, by // setting the dial option to 0. - cc.idlenessMgr = newIdlenessManager(cc, cc.dopts.idleTimeout) + cc.idlenessMgr = idle.NewManager(idle.ManagerOptions{Enforcer: (*idler)(cc), Timeout: cc.dopts.idleTimeout, Logger: logger}) // Return early for non-blocking dials. if !cc.dopts.block { @@ -317,6 +318,16 @@ func (cc *ClientConn) addTraceEvent(msg string) { channelz.AddTraceEvent(logger, cc.channelzID, 0, ted) } +type idler ClientConn + +func (i *idler) EnterIdleMode() error { + return (*ClientConn)(i).enterIdleMode() +} + +func (i *idler) ExitIdleMode() error { + return (*ClientConn)(i).exitIdleMode() +} + // exitIdleMode moves the channel out of idle mode by recreating the name // resolver and load balancer. func (cc *ClientConn) exitIdleMode() error { @@ -639,7 +650,7 @@ type ClientConn struct { channelzID *channelz.Identifier // Channelz identifier for the channel. resolverBuilder resolver.Builder // See parseTargetAndFindResolver(). balancerWrapper *ccBalancerWrapper // Uses gracefulswitch.balancer underneath. - idlenessMgr idlenessManager + idlenessMgr idle.Manager // The following provide their own synchronization, and therefore don't // require cc.mu to be held to access them. 
@@ -1268,7 +1279,7 @@ func (cc *ClientConn) Close() error { rWrapper.close() } if idlenessMgr != nil { - idlenessMgr.close() + idlenessMgr.Close() } for ac := range conns { diff --git a/internal/channelz/funcs.go b/internal/channelz/funcs.go index 777cbcd7921d..5395e77529cd 100644 --- a/internal/channelz/funcs.go +++ b/internal/channelz/funcs.go @@ -24,9 +24,7 @@ package channelz import ( - "context" "errors" - "fmt" "sort" "sync" "sync/atomic" @@ -40,8 +38,11 @@ const ( ) var ( - db dbWrapper - idGen idGenerator + // IDGen is the global channelz entity ID generator. It should not be used + // outside this package except by tests. + IDGen IDGenerator + + db dbWrapper // EntryPerPage defines the number of channelz entries to be shown on a web page. EntryPerPage = int64(50) curState int32 @@ -52,14 +53,14 @@ var ( func TurnOn() { if !IsOn() { db.set(newChannelMap()) - idGen.reset() + IDGen.Reset() atomic.StoreInt32(&curState, 1) } } // IsOn returns whether channelz data collection is on. func IsOn() bool { - return atomic.CompareAndSwapInt32(&curState, 1, 1) + return atomic.LoadInt32(&curState) == 1 } // SetMaxTraceEntry sets maximum number of trace entry per entity (i.e. channel/subchannel). @@ -97,43 +98,6 @@ func (d *dbWrapper) get() *channelMap { return d.DB } -// NewChannelzStorageForTesting initializes channelz data storage and id -// generator for testing purposes. -// -// Returns a cleanup function to be invoked by the test, which waits for up to -// 10s for all channelz state to be reset by the grpc goroutines when those -// entities get closed. This cleanup function helps with ensuring that tests -// don't mess up each other. -func NewChannelzStorageForTesting() (cleanup func() error) { - db.set(newChannelMap()) - idGen.reset() - - return func() error { - cm := db.get() - if cm == nil { - return nil - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - for { - cm.mu.RLock() - topLevelChannels, servers, channels, subChannels, listenSockets, normalSockets := len(cm.topLevelChannels), len(cm.servers), len(cm.channels), len(cm.subChannels), len(cm.listenSockets), len(cm.normalSockets) - cm.mu.RUnlock() - - if err := ctx.Err(); err != nil { - return fmt.Errorf("after 10s the channelz map has not been cleaned up yet, topchannels: %d, servers: %d, channels: %d, subchannels: %d, listen sockets: %d, normal sockets: %d", topLevelChannels, servers, channels, subChannels, listenSockets, normalSockets) - } - if topLevelChannels == 0 && servers == 0 && channels == 0 && subChannels == 0 && listenSockets == 0 && normalSockets == 0 { - return nil - } - <-ticker.C - } - } -} - // GetTopChannels returns a slice of top channel's ChannelMetric, along with a // boolean indicating whether there's more top channels to be queried for. // @@ -193,7 +157,7 @@ func GetServer(id int64) *ServerMetric { // // If channelz is not turned ON, the channelz database is not mutated. 
func RegisterChannel(c Channel, pid *Identifier, ref string) *Identifier { - id := idGen.genID() + id := IDGen.genID() var parent int64 isTopChannel := true if pid != nil { @@ -229,7 +193,7 @@ func RegisterSubChannel(c Channel, pid *Identifier, ref string) (*Identifier, er if pid == nil { return nil, errors.New("a SubChannel's parent id cannot be nil") } - id := idGen.genID() + id := IDGen.genID() if !IsOn() { return newIdentifer(RefSubChannel, id, pid), nil } @@ -251,7 +215,7 @@ func RegisterSubChannel(c Channel, pid *Identifier, ref string) (*Identifier, er // // If channelz is not turned ON, the channelz database is not mutated. func RegisterServer(s Server, ref string) *Identifier { - id := idGen.genID() + id := IDGen.genID() if !IsOn() { return newIdentifer(RefServer, id, nil) } @@ -277,7 +241,7 @@ func RegisterListenSocket(s Socket, pid *Identifier, ref string) (*Identifier, e if pid == nil { return nil, errors.New("a ListenSocket's parent id cannot be 0") } - id := idGen.genID() + id := IDGen.genID() if !IsOn() { return newIdentifer(RefListenSocket, id, pid), nil } @@ -297,7 +261,7 @@ func RegisterNormalSocket(s Socket, pid *Identifier, ref string) (*Identifier, e if pid == nil { return nil, errors.New("a NormalSocket's parent id cannot be 0") } - id := idGen.genID() + id := IDGen.genID() if !IsOn() { return newIdentifer(RefNormalSocket, id, pid), nil } @@ -776,14 +740,17 @@ func (c *channelMap) GetServer(id int64) *ServerMetric { return sm } -type idGenerator struct { +// IDGenerator is an incrementing atomic that tracks IDs for channelz entities. +type IDGenerator struct { id int64 } -func (i *idGenerator) reset() { +// Reset resets the generated ID back to zero. Should only be used at +// initialization or by tests sensitive to the ID number. +func (i *IDGenerator) Reset() { atomic.StoreInt64(&i.id, 0) } -func (i *idGenerator) genID() int64 { +func (i *IDGenerator) genID() int64 { return atomic.AddInt64(&i.id, 1) } diff --git a/internal/channelz/types.go b/internal/channelz/types.go index 7b2f350e2e64..1d4020f53795 100644 --- a/internal/channelz/types.go +++ b/internal/channelz/types.go @@ -628,6 +628,7 @@ type tracedChannel interface { type channelTrace struct { cm *channelMap + clearCalled bool createdTime time.Time eventCount int64 mu sync.Mutex @@ -656,6 +657,10 @@ func (c *channelTrace) append(e *TraceEvent) { } func (c *channelTrace) clear() { + if c.clearCalled { + return + } + c.clearCalled = true c.mu.Lock() for _, e := range c.events { if e.RefID != 0 { diff --git a/idle.go b/internal/idle/idle.go similarity index 61% rename from idle.go rename to internal/idle/idle.go index dc3dc72f6b09..6c272476e5ef 100644 --- a/idle.go +++ b/internal/idle/idle.go @@ -16,7 +16,9 @@ * */ -package grpc +// Package idle contains a component for managing idleness (entering and exiting) +// based on RPC activity. +package idle import ( "fmt" @@ -24,6 +26,8 @@ import ( "sync" "sync/atomic" "time" + + "google.golang.org/grpc/grpclog" ) // For overriding in unit tests. @@ -31,31 +35,31 @@ var timeAfterFunc = func(d time.Duration, f func()) *time.Timer { return time.AfterFunc(d, f) } -// idlenessEnforcer is the functionality provided by grpc.ClientConn to enter +// Enforcer is the functionality provided by grpc.ClientConn to enter // and exit from idle mode. 
-type idlenessEnforcer interface { - exitIdleMode() error - enterIdleMode() error +type Enforcer interface { + ExitIdleMode() error + EnterIdleMode() error } -// idlenessManager defines the functionality required to track RPC activity on a +// Manager defines the functionality required to track RPC activity on a // channel. -type idlenessManager interface { - onCallBegin() error - onCallEnd() - close() +type Manager interface { + OnCallBegin() error + OnCallEnd() + Close() } -type noopIdlenessManager struct{} +type noopManager struct{} -func (noopIdlenessManager) onCallBegin() error { return nil } -func (noopIdlenessManager) onCallEnd() {} -func (noopIdlenessManager) close() {} +func (noopManager) OnCallBegin() error { return nil } +func (noopManager) OnCallEnd() {} +func (noopManager) Close() {} -// idlenessManagerImpl implements the idlenessManager interface. It uses atomic -// operations to synchronize access to shared state and a mutex to guarantee -// mutual exclusion in a critical section. -type idlenessManagerImpl struct { +// manager implements the Manager interface. It uses atomic operations to +// synchronize access to shared state and a mutex to guarantee mutual exclusion +// in a critical section. +type manager struct { // State accessed atomically. lastCallEndTime int64 // Unix timestamp in nanos; time when the most recent RPC completed. activeCallsCount int32 // Count of active RPCs; -math.MaxInt32 means channel is idle or is trying to get there. @@ -64,14 +68,15 @@ type idlenessManagerImpl struct { // Can be accessed without atomics or mutex since these are set at creation // time and read-only after that. - enforcer idlenessEnforcer // Functionality provided by grpc.ClientConn. - timeout int64 // Idle timeout duration nanos stored as an int64. + enforcer Enforcer // Functionality provided by grpc.ClientConn. + timeout int64 // Idle timeout duration nanos stored as an int64. + logger grpclog.LoggerV2 // idleMu is used to guarantee mutual exclusion in two scenarios: // - Opposing intentions: // - a: Idle timeout has fired and handleIdleTimeout() is trying to put // the channel in idle mode because the channel has been inactive. - // - b: At the same time an RPC is made on the channel, and onCallBegin() + // - b: At the same time an RPC is made on the channel, and OnCallBegin() // is trying to prevent the channel from going idle. // - Competing intentions: // - The channel is in idle mode and there are multiple RPCs starting at @@ -83,28 +88,37 @@ type idlenessManagerImpl struct { timer *time.Timer } -// newIdlenessManager creates a new idleness manager implementation for the +// ManagerOptions is a collection of options used by +// NewManager. +type ManagerOptions struct { + Enforcer Enforcer + Timeout time.Duration + Logger grpclog.LoggerV2 +} + +// NewManager creates a new idleness manager implementation for the // given idle timeout. -func newIdlenessManager(enforcer idlenessEnforcer, idleTimeout time.Duration) idlenessManager { - if idleTimeout == 0 { - return noopIdlenessManager{} +func NewManager(opts ManagerOptions) Manager { + if opts.Timeout == 0 { + return noopManager{} } - i := &idlenessManagerImpl{ - enforcer: enforcer, - timeout: int64(idleTimeout), + m := &manager{ + enforcer: opts.Enforcer, + timeout: int64(opts.Timeout), + logger: opts.Logger, } - i.timer = timeAfterFunc(idleTimeout, i.handleIdleTimeout) - return i + m.timer = timeAfterFunc(opts.Timeout, m.handleIdleTimeout) + return m } // resetIdleTimer resets the idle timer to the given duration. 
This method // should only be called from the timer callback. -func (i *idlenessManagerImpl) resetIdleTimer(d time.Duration) { - i.idleMu.Lock() - defer i.idleMu.Unlock() +func (m *manager) resetIdleTimer(d time.Duration) { + m.idleMu.Lock() + defer m.idleMu.Unlock() - if i.timer == nil { + if m.timer == nil { // Only close sets timer to nil. We are done. return } @@ -112,47 +126,47 @@ func (i *idlenessManagerImpl) resetIdleTimer(d time.Duration) { // It is safe to ignore the return value from Reset() because this method is // only ever called from the timer callback, which means the timer has // already fired. - i.timer.Reset(d) + m.timer.Reset(d) } // handleIdleTimeout is the timer callback that is invoked upon expiry of the // configured idle timeout. The channel is considered inactive if there are no // ongoing calls and no RPC activity since the last time the timer fired. -func (i *idlenessManagerImpl) handleIdleTimeout() { - if i.isClosed() { +func (m *manager) handleIdleTimeout() { + if m.isClosed() { return } - if atomic.LoadInt32(&i.activeCallsCount) > 0 { - i.resetIdleTimer(time.Duration(i.timeout)) + if atomic.LoadInt32(&m.activeCallsCount) > 0 { + m.resetIdleTimer(time.Duration(m.timeout)) return } // There has been activity on the channel since we last got here. Reset the // timer and return. - if atomic.LoadInt32(&i.activeSinceLastTimerCheck) == 1 { + if atomic.LoadInt32(&m.activeSinceLastTimerCheck) == 1 { // Set the timer to fire after a duration of idle timeout, calculated // from the time the most recent RPC completed. - atomic.StoreInt32(&i.activeSinceLastTimerCheck, 0) - i.resetIdleTimer(time.Duration(atomic.LoadInt64(&i.lastCallEndTime) + i.timeout - time.Now().UnixNano())) + atomic.StoreInt32(&m.activeSinceLastTimerCheck, 0) + m.resetIdleTimer(time.Duration(atomic.LoadInt64(&m.lastCallEndTime) + m.timeout - time.Now().UnixNano())) return } // This CAS operation is extremely likely to succeed given that there has // been no activity since the last time we were here. Setting the - // activeCallsCount to -math.MaxInt32 indicates to onCallBegin() that the + // activeCallsCount to -math.MaxInt32 indicates to OnCallBegin() that the // channel is either in idle mode or is trying to get there. - if !atomic.CompareAndSwapInt32(&i.activeCallsCount, 0, -math.MaxInt32) { + if !atomic.CompareAndSwapInt32(&m.activeCallsCount, 0, -math.MaxInt32) { // This CAS operation can fail if an RPC started after we checked for // activity at the top of this method, or one was ongoing from before // the last time we were here. In both case, reset the timer and return. - i.resetIdleTimer(time.Duration(i.timeout)) + m.resetIdleTimer(time.Duration(m.timeout)) return } // Now that we've set the active calls count to -math.MaxInt32, it's time to // actually move to idle mode. - if i.tryEnterIdleMode() { + if m.tryEnterIdleMode() { // Successfully entered idle mode. No timer needed until we exit idle. return } @@ -160,8 +174,8 @@ func (i *idlenessManagerImpl) handleIdleTimeout() { // Failed to enter idle mode due to a concurrent RPC that kept the channel // active, or because of an error from the channel. Undo the attempt to // enter idle, and reset the timer to try again later. - atomic.AddInt32(&i.activeCallsCount, math.MaxInt32) - i.resetIdleTimer(time.Duration(i.timeout)) + atomic.AddInt32(&m.activeCallsCount, math.MaxInt32) + m.resetIdleTimer(time.Duration(m.timeout)) } // tryEnterIdleMode instructs the channel to enter idle mode. 
But before @@ -171,15 +185,15 @@ func (i *idlenessManagerImpl) handleIdleTimeout() { // Return value indicates whether or not the channel moved to idle mode. // // Holds idleMu which ensures mutual exclusion with exitIdleMode. -func (i *idlenessManagerImpl) tryEnterIdleMode() bool { - i.idleMu.Lock() - defer i.idleMu.Unlock() +func (m *manager) tryEnterIdleMode() bool { + m.idleMu.Lock() + defer m.idleMu.Unlock() - if atomic.LoadInt32(&i.activeCallsCount) != -math.MaxInt32 { + if atomic.LoadInt32(&m.activeCallsCount) != -math.MaxInt32 { // We raced and lost to a new RPC. Very rare, but stop entering idle. return false } - if atomic.LoadInt32(&i.activeSinceLastTimerCheck) == 1 { + if atomic.LoadInt32(&m.activeSinceLastTimerCheck) == 1 { // An very short RPC could have come in (and also finished) after we // checked for calls count and activity in handleIdleTimeout(), but // before the CAS operation. So, we need to check for activity again. @@ -189,99 +203,99 @@ func (i *idlenessManagerImpl) tryEnterIdleMode() bool { // No new RPCs have come in since we last set the active calls count value // -math.MaxInt32 in the timer callback. And since we have the lock, it is // safe to enter idle mode now. - if err := i.enforcer.enterIdleMode(); err != nil { - logger.Errorf("Failed to enter idle mode: %v", err) + if err := m.enforcer.EnterIdleMode(); err != nil { + m.logger.Errorf("Failed to enter idle mode: %v", err) return false } // Successfully entered idle mode. - i.actuallyIdle = true + m.actuallyIdle = true return true } -// onCallBegin is invoked at the start of every RPC. -func (i *idlenessManagerImpl) onCallBegin() error { - if i.isClosed() { +// OnCallBegin is invoked at the start of every RPC. +func (m *manager) OnCallBegin() error { + if m.isClosed() { return nil } - if atomic.AddInt32(&i.activeCallsCount, 1) > 0 { + if atomic.AddInt32(&m.activeCallsCount, 1) > 0 { // Channel is not idle now. Set the activity bit and allow the call. - atomic.StoreInt32(&i.activeSinceLastTimerCheck, 1) + atomic.StoreInt32(&m.activeSinceLastTimerCheck, 1) return nil } // Channel is either in idle mode or is in the process of moving to idle // mode. Attempt to exit idle mode to allow this RPC. - if err := i.exitIdleMode(); err != nil { + if err := m.exitIdleMode(); err != nil { // Undo the increment to calls count, and return an error causing the // RPC to fail. - atomic.AddInt32(&i.activeCallsCount, -1) + atomic.AddInt32(&m.activeCallsCount, -1) return err } - atomic.StoreInt32(&i.activeSinceLastTimerCheck, 1) + atomic.StoreInt32(&m.activeSinceLastTimerCheck, 1) return nil } // exitIdleMode instructs the channel to exit idle mode. // // Holds idleMu which ensures mutual exclusion with tryEnterIdleMode. -func (i *idlenessManagerImpl) exitIdleMode() error { - i.idleMu.Lock() - defer i.idleMu.Unlock() +func (m *manager) exitIdleMode() error { + m.idleMu.Lock() + defer m.idleMu.Unlock() - if !i.actuallyIdle { + if !m.actuallyIdle { // This can happen in two scenarios: // - handleIdleTimeout() set the calls count to -math.MaxInt32 and called // tryEnterIdleMode(). But before the latter could grab the lock, an RPC - // came in and onCallBegin() noticed that the calls count is negative. + // came in and OnCallBegin() noticed that the calls count is negative. // - Channel is in idle mode, and multiple new RPCs come in at the same - // time, all of them notice a negative calls count in onCallBegin and get + // time, all of them notice a negative calls count in OnCallBegin and get // here. 
The first one to get the lock would got the channel to exit idle. // // Either way, nothing to do here. return nil } - if err := i.enforcer.exitIdleMode(); err != nil { + if err := m.enforcer.ExitIdleMode(); err != nil { return fmt.Errorf("channel failed to exit idle mode: %v", err) } // Undo the idle entry process. This also respects any new RPC attempts. - atomic.AddInt32(&i.activeCallsCount, math.MaxInt32) - i.actuallyIdle = false + atomic.AddInt32(&m.activeCallsCount, math.MaxInt32) + m.actuallyIdle = false // Start a new timer to fire after the configured idle timeout. - i.timer = timeAfterFunc(time.Duration(i.timeout), i.handleIdleTimeout) + m.timer = timeAfterFunc(time.Duration(m.timeout), m.handleIdleTimeout) return nil } -// onCallEnd is invoked at the end of every RPC. -func (i *idlenessManagerImpl) onCallEnd() { - if i.isClosed() { +// OnCallEnd is invoked at the end of every RPC. +func (m *manager) OnCallEnd() { + if m.isClosed() { return } // Record the time at which the most recent call finished. - atomic.StoreInt64(&i.lastCallEndTime, time.Now().UnixNano()) + atomic.StoreInt64(&m.lastCallEndTime, time.Now().UnixNano()) // Decrement the active calls count. This count can temporarily go negative // when the timer callback is in the process of moving the channel to idle // mode, but one or more RPCs come in and complete before the timer callback // can get done with the process of moving to idle mode. - atomic.AddInt32(&i.activeCallsCount, -1) + atomic.AddInt32(&m.activeCallsCount, -1) } -func (i *idlenessManagerImpl) isClosed() bool { - return atomic.LoadInt32(&i.closed) == 1 +func (m *manager) isClosed() bool { + return atomic.LoadInt32(&m.closed) == 1 } -func (i *idlenessManagerImpl) close() { - atomic.StoreInt32(&i.closed, 1) +func (m *manager) Close() { + atomic.StoreInt32(&m.closed, 1) - i.idleMu.Lock() - i.timer.Stop() - i.timer = nil - i.idleMu.Unlock() + m.idleMu.Lock() + m.timer.Stop() + m.timer = nil + m.idleMu.Unlock() } diff --git a/test/idleness_test.go b/internal/idle/idle_e2e_test.go similarity index 90% rename from test/idleness_test.go rename to internal/idle/idle_e2e_test.go index 78f19edceb31..de88046bc362 100644 --- a/test/idleness_test.go +++ b/internal/idle/idle_e2e_test.go @@ -16,7 +16,7 @@ * */ -package test +package idle_test import ( "context" @@ -31,7 +31,9 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/status" @@ -40,7 +42,23 @@ import ( testpb "google.golang.org/grpc/interop/grpc_testing" ) -const defaultTestShortIdleTimeout = 500 * time.Millisecond +func init() { + channelz.TurnOn() +} + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +const ( + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 100 * time.Millisecond + defaultTestShortIdleTimeout = 500 * time.Millisecond +) // channelzTraceEventFound looks up the top-channels in channelz (expects a // single one), and checks if there is a trace event on the channel matching the @@ -84,10 +102,6 @@ func channelzTraceEventNotFound(ctx context.Context, wantDesc string) error { // Tests the case where channel idleness is disabled by passing an idle_timeout // of 0. 
Verifies that a READY channel with no RPCs does not move to IDLE. func (s) TestChannelIdleness_Disabled_NoActivity(t *testing.T) { - // Setup channelz for testing. - czCleanup := channelz.NewChannelzStorageForTesting() - t.Cleanup(func() { czCleanupWrapper(czCleanup, t) }) - // Create a ClientConn with idle_timeout set to 0. r := manual.NewBuilderWithScheme("whatever") dopts := []grpc.DialOption{ @@ -110,12 +124,12 @@ func (s) TestChannelIdleness_Disabled_NoActivity(t *testing.T) { // Verify that the ClientConn moves to READY. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) // Verify that the ClientConn stay in READY. sCtx, sCancel := context.WithTimeout(ctx, 3*defaultTestShortIdleTimeout) defer sCancel() - awaitNoStateChange(sCtx, t, cc, connectivity.Ready) + testutils.AwaitNoStateChange(sCtx, t, cc, connectivity.Ready) // Verify that there are no idleness related channelz events. if err := channelzTraceEventNotFound(ctx, "entering idle mode"); err != nil { @@ -129,10 +143,6 @@ func (s) TestChannelIdleness_Disabled_NoActivity(t *testing.T) { // Tests the case where channel idleness is enabled by passing a small value for // idle_timeout. Verifies that a READY channel with no RPCs moves to IDLE. func (s) TestChannelIdleness_Enabled_NoActivity(t *testing.T) { - // Setup channelz for testing. - czCleanup := channelz.NewChannelzStorageForTesting() - t.Cleanup(func() { czCleanupWrapper(czCleanup, t) }) - // Create a ClientConn with a short idle_timeout. r := manual.NewBuilderWithScheme("whatever") dopts := []grpc.DialOption{ @@ -155,10 +165,10 @@ func (s) TestChannelIdleness_Enabled_NoActivity(t *testing.T) { // Verify that the ClientConn moves to READY. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) // Verify that the ClientConn moves to IDLE as there is no activity. - awaitState(ctx, t, cc, connectivity.Idle) + testutils.AwaitState(ctx, t, cc, connectivity.Idle) // Verify idleness related channelz events. if err := channelzTraceEventFound(ctx, "entering idle mode"); err != nil { @@ -169,10 +179,6 @@ func (s) TestChannelIdleness_Enabled_NoActivity(t *testing.T) { // Tests the case where channel idleness is enabled by passing a small value for // idle_timeout. Verifies that a READY channel with an ongoing RPC stays READY. func (s) TestChannelIdleness_Enabled_OngoingCall(t *testing.T) { - // Setup channelz for testing. - czCleanup := channelz.NewChannelzStorageForTesting() - t.Cleanup(func() { czCleanupWrapper(czCleanup, t) }) - // Create a ClientConn with a short idle_timeout. r := manual.NewBuilderWithScheme("whatever") dopts := []grpc.DialOption{ @@ -206,17 +212,18 @@ func (s) TestChannelIdleness_Enabled_OngoingCall(t *testing.T) { // Verify that the ClientConn moves to READY. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) // Spawn a goroutine which checks expected state transitions and idleness // channelz trace events. It eventually closes `blockCh`, thereby unblocking // the server RPC handler and the unary call below. errCh := make(chan error, 1) go func() { - // Verify that the ClientConn stay in READY. 
+ defer close(blockCh) + // Verify that the ClientConn stays in READY. sCtx, sCancel := context.WithTimeout(ctx, 3*defaultTestShortIdleTimeout) defer sCancel() - awaitNoStateChange(sCtx, t, cc, connectivity.Ready) + testutils.AwaitNoStateChange(sCtx, t, cc, connectivity.Ready) // Verify that there are no idleness related channelz events. if err := channelzTraceEventNotFound(ctx, "entering idle mode"); err != nil { @@ -229,7 +236,6 @@ func (s) TestChannelIdleness_Enabled_OngoingCall(t *testing.T) { } // Unblock the unary RPC on the server. - close(blockCh) errCh <- nil }() @@ -254,10 +260,6 @@ func (s) TestChannelIdleness_Enabled_OngoingCall(t *testing.T) { // idle_timeout. Verifies that activity on a READY channel (frequent and short // RPCs) keeps it from moving to IDLE. func (s) TestChannelIdleness_Enabled_ActiveSinceLastCheck(t *testing.T) { - // Setup channelz for testing. - czCleanup := channelz.NewChannelzStorageForTesting() - t.Cleanup(func() { czCleanupWrapper(czCleanup, t) }) - // Create a ClientConn with a short idle_timeout. r := manual.NewBuilderWithScheme("whatever") dopts := []grpc.DialOption{ @@ -280,7 +282,7 @@ func (s) TestChannelIdleness_Enabled_ActiveSinceLastCheck(t *testing.T) { // Verify that the ClientConn moves to READY. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) // For a duration of three times the configured idle timeout, making RPCs // every now and then and ensure that the channel does not move out of @@ -303,7 +305,7 @@ func (s) TestChannelIdleness_Enabled_ActiveSinceLastCheck(t *testing.T) { }() // Verify that the ClientConn stay in READY. - awaitNoStateChange(sCtx, t, cc, connectivity.Ready) + testutils.AwaitNoStateChange(sCtx, t, cc, connectivity.Ready) // Verify that there are no idleness related channelz events. if err := channelzTraceEventNotFound(ctx, "entering idle mode"); err != nil { @@ -318,10 +320,6 @@ func (s) TestChannelIdleness_Enabled_ActiveSinceLastCheck(t *testing.T) { // idle_timeout. Verifies that a READY channel with no RPCs moves to IDLE. Also // verifies that a subsequent RPC on the IDLE channel kicks it out of IDLE. func (s) TestChannelIdleness_Enabled_ExitIdleOnRPC(t *testing.T) { - // Setup channelz for testing. - czCleanup := channelz.NewChannelzStorageForTesting() - t.Cleanup(func() { czCleanupWrapper(czCleanup, t) }) - // Start a test backend and set the bootstrap state of the resolver to // include this address. This will ensure that when the resolver is // restarted when exiting idle, it will push the same address to grpc again. @@ -346,10 +344,10 @@ func (s) TestChannelIdleness_Enabled_ExitIdleOnRPC(t *testing.T) { // Verify that the ClientConn moves to READY. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) // Verify that the ClientConn moves to IDLE as there is no activity. - awaitState(ctx, t, cc, connectivity.Idle) + testutils.AwaitState(ctx, t, cc, connectivity.Idle) // Verify idleness related channelz events. 
if err := channelzTraceEventFound(ctx, "entering idle mode"); err != nil { @@ -362,7 +360,7 @@ func (s) TestChannelIdleness_Enabled_ExitIdleOnRPC(t *testing.T) { if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Fatalf("EmptyCall RPC failed: %v", err) } - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) if err := channelzTraceEventFound(ctx, "exiting idle mode"); err != nil { t.Fatal(err) } @@ -380,10 +378,6 @@ func (s) TestChannelIdleness_Enabled_ExitIdleOnRPC(t *testing.T) { // // In either of these cases, all RPCs must succeed. func (s) TestChannelIdleness_Enabled_IdleTimeoutRacesWithRPCs(t *testing.T) { - // Setup channelz for testing. - czCleanup := channelz.NewChannelzStorageForTesting() - t.Cleanup(func() { czCleanupWrapper(czCleanup, t) }) - // Start a test backend and set the bootstrap state of the resolver to // include this address. This will ensure that when the resolver is // restarted when exiting idle, it will push the same address to grpc again. @@ -452,11 +446,11 @@ func (s) TestChannelIdleness_Connect(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Idle) + testutils.AwaitState(ctx, t, cc, connectivity.Idle) // Connect should exit channel idleness. cc.Connect() // Verify that the ClientConn moves back to READY. - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) } diff --git a/idle_test.go b/internal/idle/idle_test.go similarity index 69% rename from idle_test.go rename to internal/idle/idle_test.go index 9b60cb5a5a1a..22bde3ba1422 100644 --- a/idle_test.go +++ b/internal/idle/idle_test.go @@ -16,7 +16,7 @@ * */ -package grpc +package idle import ( "context" @@ -25,32 +25,44 @@ import ( "sync/atomic" "testing" "time" + + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/grpctest" ) const ( + defaultTestTimeout = 10 * time.Second defaultTestIdleTimeout = 500 * time.Millisecond // A short idle_timeout for tests. defaultTestShortTimeout = 10 * time.Millisecond // A small deadline to wait for events expected to not happen. ) -type testIdlenessEnforcer struct { +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type testEnforcer struct { exitIdleCh chan struct{} enterIdleCh chan struct{} } -func (ti *testIdlenessEnforcer) exitIdleMode() error { +func (ti *testEnforcer) ExitIdleMode() error { ti.exitIdleCh <- struct{}{} return nil } -func (ti *testIdlenessEnforcer) enterIdleMode() error { +func (ti *testEnforcer) EnterIdleMode() error { ti.enterIdleCh <- struct{}{} return nil } -func newTestIdlenessEnforcer() *testIdlenessEnforcer { - return &testIdlenessEnforcer{ +func newTestEnforcer() *testEnforcer { + return &testEnforcer{ exitIdleCh: make(chan struct{}, 1), enterIdleCh: make(chan struct{}, 1), } @@ -76,19 +88,19 @@ func overrideNewTimer(t *testing.T) <-chan struct{} { return ch } -// TestIdlenessManager_Disabled tests the case where the idleness manager is +// TestManager_Disabled tests the case where the idleness manager is // disabled by passing an idle_timeout of 0. 
Verifies the following things: // - timer callback does not fire -// - an RPC does not trigger a call to exitIdleMode on the ClientConn +// - an RPC does not trigger a call to ExitIdleMode on the ClientConn // - more calls to RPC termination (as compared to RPC initiation) does not // result in an error log -func (s) TestIdlenessManager_Disabled(t *testing.T) { +func (s) TestManager_Disabled(t *testing.T) { callbackCh := overrideNewTimer(t) // Create an idleness manager that is disabled because of idleTimeout being // set to `0`. - enforcer := newTestIdlenessEnforcer() - mgr := newIdlenessManager(enforcer, time.Duration(0)) + enforcer := newTestEnforcer() + mgr := NewManager(ManagerOptions{Enforcer: enforcer, Timeout: time.Duration(0), Logger: grpclog.Component("test")}) // Ensure that the timer callback does not fire within a short deadline. select { @@ -97,36 +109,36 @@ func (s) TestIdlenessManager_Disabled(t *testing.T) { case <-time.After(defaultTestShortTimeout): } - // The first invocation of onCallBegin() would lead to a call to - // exitIdleMode() on the enforcer, unless the idleness manager is disabled. - mgr.onCallBegin() + // The first invocation of OnCallBegin() would lead to a call to + // ExitIdleMode() on the enforcer, unless the idleness manager is disabled. + mgr.OnCallBegin() select { case <-enforcer.exitIdleCh: - t.Fatalf("exitIdleMode() called on enforcer when manager is disabled") + t.Fatalf("ExitIdleMode() called on enforcer when manager is disabled") case <-time.After(defaultTestShortTimeout): } - // If the number of calls to onCallEnd() exceeds the number of calls to - // onCallBegin(), the idleness manager is expected to throw an error log + // If the number of calls to OnCallEnd() exceeds the number of calls to + // OnCallBegin(), the idleness manager is expected to throw an error log // (which will cause our TestLogger to fail the test). But since the manager // is disabled, this should not happen. - mgr.onCallEnd() - mgr.onCallEnd() + mgr.OnCallEnd() + mgr.OnCallEnd() // The idleness manager is explicitly not closed here. But since the manager // is disabled, it will not start the run goroutine, and hence we expect the // leakchecker to not find any leaked goroutines. } -// TestIdlenessManager_Enabled_TimerFires tests the case where the idle manager +// TestManager_Enabled_TimerFires tests the case where the idle manager // is enabled. Ensures that when there are no RPCs, the timer callback is -// invoked and the enterIdleMode() method is invoked on the enforcer. -func (s) TestIdlenessManager_Enabled_TimerFires(t *testing.T) { +// invoked and the EnterIdleMode() method is invoked on the enforcer. +func (s) TestManager_Enabled_TimerFires(t *testing.T) { callbackCh := overrideNewTimer(t) - enforcer := newTestIdlenessEnforcer() - mgr := newIdlenessManager(enforcer, time.Duration(defaultTestIdleTimeout)) - defer mgr.close() + enforcer := newTestEnforcer() + mgr := NewManager(ManagerOptions{Enforcer: enforcer, Timeout: time.Duration(defaultTestIdleTimeout), Logger: grpclog.Component("test")}) + defer mgr.Close() // Ensure that the timer callback fires within a appropriate amount of time. select { @@ -143,23 +155,23 @@ func (s) TestIdlenessManager_Enabled_TimerFires(t *testing.T) { } } -// TestIdlenessManager_Enabled_OngoingCall tests the case where the idle manager +// TestManager_Enabled_OngoingCall tests the case where the idle manager // is enabled. Ensures that when there is an ongoing RPC, the channel does not // enter idle mode. 
-func (s) TestIdlenessManager_Enabled_OngoingCall(t *testing.T) { +func (s) TestManager_Enabled_OngoingCall(t *testing.T) { callbackCh := overrideNewTimer(t) - enforcer := newTestIdlenessEnforcer() - mgr := newIdlenessManager(enforcer, time.Duration(defaultTestIdleTimeout)) - defer mgr.close() + enforcer := newTestEnforcer() + mgr := NewManager(ManagerOptions{Enforcer: enforcer, Timeout: time.Duration(defaultTestIdleTimeout), Logger: grpclog.Component("test")}) + defer mgr.Close() // Fire up a goroutine that simulates an ongoing RPC that is terminated // after the timer callback fires for the first time. timerFired := make(chan struct{}) go func() { - mgr.onCallBegin() + mgr.OnCallBegin() <-timerFired - mgr.onCallEnd() + mgr.OnCallEnd() }() // Ensure that the timer callback fires and unblock the above goroutine. @@ -174,7 +186,7 @@ func (s) TestIdlenessManager_Enabled_OngoingCall(t *testing.T) { // mode since we had an ongoing RPC. select { case <-enforcer.enterIdleCh: - t.Fatalf("enterIdleMode() called on enforcer when active RPC exists") + t.Fatalf("EnterIdleMode() called on enforcer when active RPC exists") case <-time.After(defaultTestShortTimeout): } @@ -187,24 +199,24 @@ func (s) TestIdlenessManager_Enabled_OngoingCall(t *testing.T) { } } -// TestIdlenessManager_Enabled_ActiveSinceLastCheck tests the case where the +// TestManager_Enabled_ActiveSinceLastCheck tests the case where the // idle manager is enabled. Ensures that when there are active RPCs in the last // period (even though there is no active call when the timer fires), the // channel does not enter idle mode. -func (s) TestIdlenessManager_Enabled_ActiveSinceLastCheck(t *testing.T) { +func (s) TestManager_Enabled_ActiveSinceLastCheck(t *testing.T) { callbackCh := overrideNewTimer(t) - enforcer := newTestIdlenessEnforcer() - mgr := newIdlenessManager(enforcer, time.Duration(defaultTestIdleTimeout)) - defer mgr.close() + enforcer := newTestEnforcer() + mgr := NewManager(ManagerOptions{Enforcer: enforcer, Timeout: time.Duration(defaultTestIdleTimeout), Logger: grpclog.Component("test")}) + defer mgr.Close() // Fire up a goroutine that simulates unary RPCs until the timer callback // fires. timerFired := make(chan struct{}) go func() { for ; ; <-time.After(defaultTestShortTimeout) { - mgr.onCallBegin() - mgr.onCallEnd() + mgr.OnCallBegin() + mgr.OnCallEnd() select { case <-timerFired: @@ -225,7 +237,7 @@ func (s) TestIdlenessManager_Enabled_ActiveSinceLastCheck(t *testing.T) { } select { case <-enforcer.enterIdleCh: - t.Fatalf("enterIdleMode() called on enforcer when one RPC completed in the last period") + t.Fatalf("EnterIdleMode() called on enforcer when one RPC completed in the last period") case <-time.After(defaultTestShortTimeout): } @@ -238,15 +250,15 @@ func (s) TestIdlenessManager_Enabled_ActiveSinceLastCheck(t *testing.T) { } } -// TestIdlenessManager_Enabled_ExitIdleOnRPC tests the case where the idle +// TestManager_Enabled_ExitIdleOnRPC tests the case where the idle // manager is enabled. Ensures that the channel moves out of idle when an RPC is // initiated. 
-func (s) TestIdlenessManager_Enabled_ExitIdleOnRPC(t *testing.T) { +func (s) TestManager_Enabled_ExitIdleOnRPC(t *testing.T) { overrideNewTimer(t) - enforcer := newTestIdlenessEnforcer() - mgr := newIdlenessManager(enforcer, time.Duration(defaultTestIdleTimeout)) - defer mgr.close() + enforcer := newTestEnforcer() + mgr := NewManager(ManagerOptions{Enforcer: enforcer, Timeout: time.Duration(defaultTestIdleTimeout), Logger: grpclog.Component("test")}) + defer mgr.Close() // Ensure that the channel moves to idle since there are no RPCs. select { @@ -256,12 +268,12 @@ func (s) TestIdlenessManager_Enabled_ExitIdleOnRPC(t *testing.T) { } for i := 0; i < 100; i++ { - // A call to onCallBegin and onCallEnd simulates an RPC. + // A call to OnCallBegin and OnCallEnd simulates an RPC. go func() { - if err := mgr.onCallBegin(); err != nil { - t.Errorf("onCallBegin() failed: %v", err) + if err := mgr.OnCallBegin(); err != nil { + t.Errorf("OnCallBegin() failed: %v", err) } - mgr.onCallEnd() + mgr.OnCallEnd() }() } @@ -282,10 +294,10 @@ func (s) TestIdlenessManager_Enabled_ExitIdleOnRPC(t *testing.T) { } } -type racyIdlenessState int32 +type racyState int32 const ( - stateInital racyIdlenessState = iota + stateInital racyState = iota stateEnteredIdle stateExitedIdle stateActiveRPCs @@ -293,44 +305,44 @@ const ( // racyIdlnessEnforcer is a test idleness enforcer used specifically to test the // race between idle timeout and incoming RPCs. -type racyIdlenessEnforcer struct { - state *racyIdlenessState // Accessed atomically. +type racyEnforcer struct { + state *racyState // Accessed atomically. } -// exitIdleMode sets the internal state to stateExitedIdle. We should only ever +// ExitIdleMode sets the internal state to stateExitedIdle. We should only ever // exit idle when we are currently in idle. -func (ri *racyIdlenessEnforcer) exitIdleMode() error { +func (ri *racyEnforcer) ExitIdleMode() error { if !atomic.CompareAndSwapInt32((*int32)(ri.state), int32(stateEnteredIdle), int32(stateExitedIdle)) { return fmt.Errorf("idleness enforcer asked to exit idle when it did not enter idle earlier") } return nil } -// enterIdleMode attempts to set the internal state to stateEnteredIdle. We should only ever enter idle before RPCs start. -func (ri *racyIdlenessEnforcer) enterIdleMode() error { +// EnterIdleMode attempts to set the internal state to stateEnteredIdle. We should only ever enter idle before RPCs start. +func (ri *racyEnforcer) EnterIdleMode() error { if !atomic.CompareAndSwapInt32((*int32)(ri.state), int32(stateInital), int32(stateEnteredIdle)) { return fmt.Errorf("idleness enforcer asked to enter idle after rpcs started") } return nil } -// TestIdlenessManager_IdleTimeoutRacesWithOnCallBegin tests the case where +// TestManager_IdleTimeoutRacesWithOnCallBegin tests the case where // firing of the idle timeout races with an incoming RPC. The test verifies that // if the timer callback win the race and puts the channel in idle, the RPCs can // kick it out of idle. And if the RPCs win the race and keep the channel // active, then the timer callback should not attempt to put the channel in idle // mode. -func (s) TestIdlenessManager_IdleTimeoutRacesWithOnCallBegin(t *testing.T) { +func (s) TestManager_IdleTimeoutRacesWithOnCallBegin(t *testing.T) { // Run multiple iterations to simulate different possibilities. 
for i := 0; i < 20; i++ { t.Run(fmt.Sprintf("iteration=%d", i), func(t *testing.T) { - var idlenessState racyIdlenessState - enforcer := &racyIdlenessEnforcer{state: &idlenessState} + var idlenessState racyState + enforcer := &racyEnforcer{state: &idlenessState} // Configure a large idle timeout so that we can control the // race between the timer callback and RPCs. - mgr := newIdlenessManager(enforcer, time.Duration(10*time.Minute)) - defer mgr.close() + mgr := NewManager(ManagerOptions{Enforcer: enforcer, Timeout: time.Duration(10 * time.Minute), Logger: grpclog.Component("test")}) + defer mgr.Close() var wg sync.WaitGroup wg.Add(1) @@ -347,11 +359,11 @@ func (s) TestIdlenessManager_IdleTimeoutRacesWithOnCallBegin(t *testing.T) { // Wait for the configured idle timeout and simulate an RPC to // race with the idle timeout timer callback. <-time.After(defaultTestIdleTimeout / 10) - if err := mgr.onCallBegin(); err != nil { - t.Errorf("onCallBegin() failed: %v", err) + if err := mgr.OnCallBegin(); err != nil { + t.Errorf("OnCallBegin() failed: %v", err) } atomic.StoreInt32((*int32)(&idlenessState), int32(stateActiveRPCs)) - mgr.onCallEnd() + mgr.OnCallEnd() }() } wg.Wait() diff --git a/internal/testutils/state.go b/internal/testutils/state.go new file mode 100644 index 000000000000..246b07a7ea19 --- /dev/null +++ b/internal/testutils/state.go @@ -0,0 +1,85 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package testutils + +import ( + "context" + "testing" + + "google.golang.org/grpc/connectivity" +) + +// A StateChanger reports state changes, e.g. a grpc.ClientConn. +type StateChanger interface { + // Connect begins connecting the StateChanger. + Connect() + // GetState returns the current state of the StateChanger. + GetState() connectivity.State + // WaitForStateChange returns true when the state becomes s, or returns + // false if ctx is canceled first. + WaitForStateChange(ctx context.Context, s connectivity.State) bool +} + +// StayConnected makes sc stay connected by repeatedly calling sc.Connect() +// until the state becomes Shutdown or until ithe context expires. +func StayConnected(ctx context.Context, sc StateChanger) { + for { + state := sc.GetState() + switch state { + case connectivity.Idle: + sc.Connect() + case connectivity.Shutdown: + return + } + if !sc.WaitForStateChange(ctx, state) { + return + } + } +} + +// AwaitState waits for sc to enter stateWant or fatal errors if it doesn't +// happen before ctx expires. +func AwaitState(ctx context.Context, t *testing.T, sc StateChanger, stateWant connectivity.State) { + t.Helper() + for state := sc.GetState(); state != stateWant; state = sc.GetState() { + if !sc.WaitForStateChange(ctx, state) { + t.Fatalf("Timed out waiting for state change. got %v; want %v", state, stateWant) + } + } +} + +// AwaitNotState waits for sc to leave stateDoNotWant or fatal errors if it +// doesn't happen before ctx expires. 
+func AwaitNotState(ctx context.Context, t *testing.T, sc StateChanger, stateDoNotWant connectivity.State) { + t.Helper() + for state := sc.GetState(); state == stateDoNotWant; state = sc.GetState() { + if !sc.WaitForStateChange(ctx, state) { + t.Fatalf("Timed out waiting for state change. got %v; want NOT %v", state, stateDoNotWant) + } + } +} + +// AwaitNoStateChange expects ctx to be canceled before sc's state leaves +// currState, and fatal errors otherwise. +func AwaitNoStateChange(ctx context.Context, t *testing.T, sc StateChanger, currState connectivity.State) { + t.Helper() + if sc.WaitForStateChange(ctx, currState) { + t.Fatalf("State changed from %q to %q when no state change was expected", currState, sc.GetState()) + } +} diff --git a/stream.go b/stream.go index cf73147bdd22..7e7f23445191 100644 --- a/stream.go +++ b/stream.go @@ -156,10 +156,10 @@ type ClientStream interface { // If none of the above happen, a goroutine and a context will be leaked, and grpc // will not call the optionally-configured stats handler with a stats.End message. func (cc *ClientConn) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) { - if err := cc.idlenessMgr.onCallBegin(); err != nil { + if err := cc.idlenessMgr.OnCallBegin(); err != nil { return nil, err } - defer cc.idlenessMgr.onCallEnd() + defer cc.idlenessMgr.OnCallEnd() // allow interceptor to see all applicable call options, which means those // configured as defaults from dial option as well as per-call options diff --git a/test/balancer_switching_test.go b/test/balancer_switching_test.go index 1a6a89d36294..5decc4d3b83b 100644 --- a/test/balancer_switching_test.go +++ b/test/balancer_switching_test.go @@ -29,7 +29,6 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/balancer/stub" - "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils/fakegrpclb" "google.golang.org/grpc/internal/testutils/pickfirst" @@ -63,7 +62,6 @@ const ( // // Returns a cleanup function to be invoked by the caller. 
func setupBackendsAndFakeGRPCLB(t *testing.T) ([]*stubserver.StubServer, *fakegrpclb.Server, func()) { - czCleanup := channelz.NewChannelzStorageForTesting() backends, backendsCleanup := startBackendsForBalancerSwitch(t) lbServer, err := fakegrpclb.NewServer(fakegrpclb.ServerParams{ @@ -83,7 +81,6 @@ func setupBackendsAndFakeGRPCLB(t *testing.T) ([]*stubserver.StubServer, *fakegr return backends, lbServer, func() { backendsCleanup() lbServer.Stop() - czCleanupWrapper(czCleanup, t) } } diff --git a/test/channelz_linux_test.go b/test/channelz_linux_test.go index 7d1407323334..d5b691c1d83e 100644 --- a/test/channelz_linux_test.go +++ b/test/channelz_linux_test.go @@ -35,8 +35,6 @@ func (s) TestCZSocketMetricsSocketOption(t *testing.T) { } func testCZSocketMetricsSocketOption(t *testing.T, e env) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) te := newTest(t, e) te.startServer(&testServer{security: e.security}) defer te.tearDown() diff --git a/test/channelz_test.go b/test/channelz_test.go index aa29198b260b..b23acf4bdc1d 100644 --- a/test/channelz_test.go +++ b/test/channelz_test.go @@ -41,6 +41,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" @@ -51,12 +52,6 @@ import ( testpb "google.golang.org/grpc/interop/grpc_testing" ) -func czCleanupWrapper(cleanup func() error, t *testing.T) { - if err := cleanup(); err != nil { - t.Error(err) - } -} - func verifyResultWithDelay(f func() (bool, error)) error { var ok bool var err error @@ -85,28 +80,27 @@ func (s) TestCZServerRegistrationAndDeletion(t *testing.T) { {total: int(channelz.EntryPerPage), start: 0, max: channelz.EntryPerPage - 1, length: channelz.EntryPerPage - 1, end: false}, } - for _, c := range testcases { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) + for i, c := range testcases { + // Reset channelz IDs so `start` is valid. + channelz.IDGen.Reset() + e := tcpClearRREnv te := newTest(t, e) te.startServers(&testServer{security: e.security}, c.total) ss, end := channelz.GetServers(c.start, c.max) if int64(len(ss)) != c.length || end != c.end { - t.Fatalf("GetServers(%d) = %+v (len of which: %d), end: %+v, want len(GetServers(%d)) = %d, end: %+v", c.start, ss, len(ss), end, c.start, c.length, c.end) + t.Fatalf("%d: GetServers(%d) = %+v (len of which: %d), end: %+v, want len(GetServers(%d)) = %d, end: %+v", i, c.start, ss, len(ss), end, c.start, c.length, c.end) } te.tearDown() ss, end = channelz.GetServers(c.start, c.max) if len(ss) != 0 || !end { - t.Fatalf("GetServers(0) = %+v (len of which: %d), end: %+v, want len(GetServers(0)) = 0, end: true", ss, len(ss), end) + t.Fatalf("%d: GetServers(0) = %+v (len of which: %d), end: %+v, want len(GetServers(0)) = 0, end: true", i, ss, len(ss), end) } } } func (s) TestCZGetServer(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) te.startServer(&testServer{security: e.security}) @@ -157,8 +151,9 @@ func (s) TestCZTopChannelRegistrationAndDeletion(t *testing.T) { } for _, c := range testcases { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) + // Reset channelz IDs so `start` is valid. 
+ channelz.IDGen.Reset() + e := tcpClearRREnv te := newTest(t, e) var ccs []*grpc.ClientConn @@ -195,8 +190,6 @@ func (s) TestCZTopChannelRegistrationAndDeletion(t *testing.T) { } func (s) TestCZTopChannelRegistrationAndDeletionWhenDialFail(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) // Make dial fails (due to no transport security specified) _, err := grpc.Dial("fake.addr") if err == nil { @@ -208,8 +201,6 @@ func (s) TestCZTopChannelRegistrationAndDeletionWhenDialFail(t *testing.T) { } func (s) TestCZNestedChannelRegistrationAndDeletion(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv // avoid calling API to set balancer type, which will void service config's change of balancer. e.balancer = "" @@ -256,8 +247,6 @@ func (s) TestCZNestedChannelRegistrationAndDeletion(t *testing.T) { } func (s) TestCZClientSubChannelSocketRegistrationAndDeletion(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv num := 3 // number of backends te := newTest(t, e) @@ -344,8 +333,9 @@ func (s) TestCZServerSocketRegistrationAndDeletion(t *testing.T) { } for _, c := range testcases { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) + // Reset channelz IDs so `start` is valid. + channelz.IDGen.Reset() + e := tcpClearRREnv te := newTest(t, e) te.startServer(&testServer{security: e.security}) @@ -404,8 +394,6 @@ func (s) TestCZServerSocketRegistrationAndDeletion(t *testing.T) { } func (s) TestCZServerListenSocketDeletion(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) s := grpc.NewServer() lis, err := net.Listen("tcp", "localhost:0") if err != nil { @@ -461,8 +449,7 @@ func (s) TestCZRecusivelyDeletionOfEntry(t *testing.T) { // | | // v v // Socket1 Socket2 - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) + topChanID := channelz.RegisterChannel(&dummyChannel{}, nil, "") subChanID1, _ := channelz.RegisterSubChannel(&dummyChannel{}, topChanID, "") subChanID2, _ := channelz.RegisterSubChannel(&dummyChannel{}, topChanID, "") @@ -506,8 +493,6 @@ func (s) TestCZRecusivelyDeletionOfEntry(t *testing.T) { } func (s) TestCZChannelMetrics(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv num := 3 // number of backends te := newTest(t, e) @@ -596,8 +581,6 @@ func (s) TestCZChannelMetrics(t *testing.T) { } func (s) TestCZServerMetrics(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) te.maxServerReceiveMsgSize = newInt(8) @@ -868,8 +851,6 @@ func (s) TestCZClientSocketMetricsStreamsAndMessagesCount(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) te.maxServerReceiveMsgSize = newInt(20) @@ -968,8 +949,6 @@ func (s) TestCZClientSocketMetricsStreamsAndMessagesCount(t *testing.T) { // It is separated from other cases due to setup incompatibly, i.e. max receive // size violation will mask flow control violation. 
func (s) TestCZClientAndServerSocketMetricsStreamsCountFlowControlRSTStream(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) te.serverInitialWindowSize = 65536 @@ -1052,8 +1031,6 @@ func (s) TestCZClientAndServerSocketMetricsStreamsCountFlowControlRSTStream(t *t } func (s) TestCZClientAndServerSocketMetricsFlowControl(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) // disable BDP @@ -1166,8 +1143,6 @@ func (s) TestCZClientAndServerSocketMetricsFlowControl(t *testing.T) { func (s) TestCZClientSocketMetricsKeepAlive(t *testing.T) { const keepaliveRate = 50 * time.Millisecond - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) defer func(t time.Duration) { internal.KeepaliveMinPingTime = t }(internal.KeepaliveMinPingTime) internal.KeepaliveMinPingTime = keepaliveRate e := tcpClearRREnv @@ -1187,7 +1162,7 @@ func (s) TestCZClientSocketMetricsKeepAlive(t *testing.T) { cc := te.clientConn() // Dial the server ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) start := time.Now() // Wait for at least two keepalives to be able to occur. time.Sleep(2 * keepaliveRate) @@ -1226,8 +1201,6 @@ func (s) TestCZClientSocketMetricsKeepAlive(t *testing.T) { } func (s) TestCZServerSocketMetricsStreamsAndMessagesCount(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) te.maxServerReceiveMsgSize = newInt(20) @@ -1290,8 +1263,6 @@ func (s) TestCZServerSocketMetricsKeepAlive(t *testing.T) { defer func(t time.Duration) { internal.KeepaliveMinServerPingTime = t }(internal.KeepaliveMinServerPingTime) internal.KeepaliveMinServerPingTime = 50 * time.Millisecond - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) // We setup the server keepalive parameters to send one keepalive every @@ -1311,7 +1282,7 @@ func (s) TestCZServerSocketMetricsKeepAlive(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) // Allow about 5 pings to happen (250ms/50ms). time.Sleep(255 * time.Millisecond) @@ -1360,8 +1331,6 @@ var cipherSuites = []string{ } func (s) TestCZSocketGetSecurityValueTLS(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpTLSRREnv te := newTest(t, e) te.startServer(&testServer{security: e.security}) @@ -1410,9 +1379,6 @@ func (s) TestCZSocketGetSecurityValueTLS(t *testing.T) { } func (s) TestCZChannelTraceCreationDeletion(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) - e := tcpClearRREnv // avoid calling API to set balancer type, which will void service config's change of balancer. 
e.balancer = "" @@ -1493,8 +1459,6 @@ func (s) TestCZChannelTraceCreationDeletion(t *testing.T) { } func (s) TestCZSubChannelTraceCreationDeletion(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) te.startServer(&testServer{security: e.security}) @@ -1543,9 +1507,9 @@ func (s) TestCZSubChannelTraceCreationDeletion(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, te.cc, connectivity.Ready) + testutils.AwaitState(ctx, t, te.cc, connectivity.Ready) r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: "fake address"}}}) - awaitNotState(ctx, t, te.cc, connectivity.Ready) + testutils.AwaitNotState(ctx, t, te.cc, connectivity.Ready) if err := verifyResultWithDelay(func() (bool, error) { tcs, _ := channelz.GetTopChannels(0, 0) @@ -1578,8 +1542,6 @@ func (s) TestCZSubChannelTraceCreationDeletion(t *testing.T) { } func (s) TestCZChannelAddressResolutionChange(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv e.balancer = "" te := newTest(t, e) @@ -1684,8 +1646,6 @@ func (s) TestCZChannelAddressResolutionChange(t *testing.T) { } func (s) TestCZSubChannelPickedNewAddress(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv e.balancer = "" te := newTest(t, e) @@ -1756,8 +1716,6 @@ func (s) TestCZSubChannelPickedNewAddress(t *testing.T) { } func (s) TestCZSubChannelConnectivityState(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) te.startServer(&testServer{security: e.security}) @@ -1855,8 +1813,6 @@ func (s) TestCZSubChannelConnectivityState(t *testing.T) { } func (s) TestCZChannelConnectivityState(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) te.startServer(&testServer{security: e.security}) @@ -1914,8 +1870,6 @@ func (s) TestCZChannelConnectivityState(t *testing.T) { } func (s) TestCZTraceOverwriteChannelDeletion(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv e.balancer = "" te := newTest(t, e) @@ -1984,8 +1938,6 @@ func (s) TestCZTraceOverwriteChannelDeletion(t *testing.T) { } func (s) TestCZTraceOverwriteSubChannelDeletion(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) channelz.SetMaxTraceEntry(1) @@ -2017,9 +1969,9 @@ func (s) TestCZTraceOverwriteSubChannelDeletion(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, te.cc, connectivity.Ready) + testutils.AwaitState(ctx, t, te.cc, connectivity.Ready) r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: "fake address"}}}) - awaitNotState(ctx, t, te.cc, connectivity.Ready) + testutils.AwaitNotState(ctx, t, te.cc, connectivity.Ready) // verify that the subchannel no longer exist due to trace referencing it got overwritten. 
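The awaitState and awaitNotState helpers that these hunks replace now live in internal/testutils as AwaitState and AwaitNotState. Like the local versions they replace, they wait on the channel (via GetState and WaitForStateChange) until it reaches, or for AwaitNotState leaves, the given state, failing the test if the context expires first. A rough usage sketch follows, with the resolver flip borrowed from the subchannel-trace tests above; the wrapper function name is invented for the example.

// transitionAwayFromReady points the manual resolver at an address with no
// server behind it and waits for the channel to leave READY.
func transitionAwayFromReady(ctx context.Context, t *testing.T, r *manual.Resolver, cc *grpc.ClientConn) {
	t.Helper()

	// Block until the channel has finished connecting to whatever the
	// resolver currently points at.
	testutils.AwaitState(ctx, t, cc, connectivity.Ready)

	// Break the connection by resolving to a dead address; the channel must
	// eventually report something other than READY, or the test fails when
	// ctx expires.
	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: "fake address"}}})
	testutils.AwaitNotState(ctx, t, cc, connectivity.Ready)
}
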
if err := verifyResultWithDelay(func() (bool, error) { @@ -2034,8 +1986,6 @@ func (s) TestCZTraceOverwriteSubChannelDeletion(t *testing.T) { } func (s) TestCZTraceTopChannelDeletionTraceClear(t *testing.T) { - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) e := tcpClearRREnv te := newTest(t, e) te.startServer(&testServer{security: e.security}) diff --git a/test/clientconn_state_transition_test.go b/test/clientconn_state_transition_test.go index 00660168f1b5..6e9bfb37289d 100644 --- a/test/clientconn_state_transition_test.go +++ b/test/clientconn_state_transition_test.go @@ -175,7 +175,7 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - go stayConnected(ctx, client) + go testutils.StayConnected(ctx, client) stateNotifications := testBalancerBuilder.nextStateNotifier() for i := 0; i < len(want); i++ { @@ -242,7 +242,7 @@ func (s) TestStateTransitions_ReadyToConnecting(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - go stayConnected(ctx, client) + go testutils.StayConnected(ctx, client) stateNotifications := testBalancerBuilder.nextStateNotifier() @@ -417,7 +417,7 @@ func (s) TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - go stayConnected(ctx, client) + go testutils.StayConnected(ctx, client) stateNotifications := testBalancerBuilder.nextStateNotifier() want := []connectivity.State{ @@ -509,48 +509,6 @@ func keepReading(conn net.Conn) { } } -// stayConnected makes cc stay connected by repeatedly calling cc.Connect() -// until the state becomes Shutdown or until ithe context expires. -func stayConnected(ctx context.Context, cc *grpc.ClientConn) { - for { - state := cc.GetState() - switch state { - case connectivity.Idle: - cc.Connect() - case connectivity.Shutdown: - return - } - if !cc.WaitForStateChange(ctx, state) { - return - } - } -} - -func awaitState(ctx context.Context, t *testing.T, cc *grpc.ClientConn, stateWant connectivity.State) { - t.Helper() - for state := cc.GetState(); state != stateWant; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Fatalf("timed out waiting for state change. got %v; want %v", state, stateWant) - } - } -} - -func awaitNotState(ctx context.Context, t *testing.T, cc *grpc.ClientConn, stateDoNotWant connectivity.State) { - t.Helper() - for state := cc.GetState(); state == stateDoNotWant; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Fatalf("timed out waiting for state change. got %v; want NOT %v", state, stateDoNotWant) - } - } -} - -func awaitNoStateChange(ctx context.Context, t *testing.T, cc *grpc.ClientConn, currState connectivity.State) { - t.Helper() - if cc.WaitForStateChange(ctx, currState) { - t.Fatalf("State changed from %q to %q when no state change was expected", currState, cc.GetState()) - } -} - type funcConnectivityStateSubscriber struct { onMsg func(connectivity.State) } diff --git a/test/clientconn_test.go b/test/clientconn_test.go index bdbe81d03040..4432701fb3d1 100644 --- a/test/clientconn_test.go +++ b/test/clientconn_test.go @@ -40,10 +40,6 @@ import ( // blocking. The test verifies that closing the ClientConn unblocks the RPC with // the expected error code. func (s) TestClientConnClose_WithPendingRPC(t *testing.T) { - // Initialize channelz. 
Used to determine pending RPC count. - czCleanup := channelz.NewChannelzStorageForTesting() - defer czCleanupWrapper(czCleanup, t) - r := manual.NewBuilderWithScheme("whatever") cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r)) if err != nil { diff --git a/test/creds_test.go b/test/creds_test.go index 28e16eb543fc..ef73fe936c80 100644 --- a/test/creds_test.go +++ b/test/creds_test.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" @@ -445,7 +446,7 @@ func (s) TestCredsHandshakeAuthority(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) if cred.got != testAuthority { t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) @@ -477,7 +478,7 @@ func (s) TestCredsHandshakeServerNameAuthority(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) if cred.got != testServerName { t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) diff --git a/test/end2end_test.go b/test/end2end_test.go index 948382d04f99..9d1b4a78b0ce 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -696,6 +696,7 @@ func (w wrapHS) GracefulStop() { func (w wrapHS) Stop() { w.s.Close() + w.s.Handler.(*grpc.Server).Stop() } func (te *test) startServerWithConnControl(ts testgrpc.TestServiceServer) *listenerWrapper { @@ -984,9 +985,9 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { } // Wait for the client to report READY, stop the server, then wait for the // client to notice the connection is gone. - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) te.srv.Stop() - awaitNotState(ctx, t, cc, connectivity.Ready) + testutils.AwaitNotState(ctx, t, cc, connectivity.Ready) ctx, cancel = context.WithTimeout(ctx, defaultTestShortTimeout) defer cancel() if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) != codes.DeadlineExceeded { @@ -3031,7 +3032,9 @@ func (s) TestTransparentRetry(t *testing.T) { func (s) TestCancel(t *testing.T) { for _, e := range listTestEnv() { - testCancel(t, e) + t.Run(e.name, func(t *testing.T) { + testCancel(t, e) + }) } } @@ -4855,7 +4858,7 @@ func testWaitForReadyConnection(t *testing.T, e env) { tc := testgrpc.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) // Make a fail-fast RPC. 
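stayConnected has likewise moved to internal/testutils as StayConnected: per the doc comment on the removed local helper, it keeps calling cc.Connect() whenever the channel reports IDLE and returns once the channel shuts down or the context expires. Below is a short sketch of how the state-transition tests above lean on it; the wrapper function here is invented for the example.

// connectWithoutRPCs drives the channel to READY in a test that never issues
// an RPC, by letting StayConnected kick the channel out of IDLE (including
// after a connection is lost) while the test only watches states.
func connectWithoutRPCs(ctx context.Context, t *testing.T, cc *grpc.ClientConn) {
	t.Helper()

	go testutils.StayConnected(ctx, cc)

	// With reconnects handled in the background, the test can simply wait for
	// the state it cares about.
	testutils.AwaitState(ctx, t, cc, connectivity.Ready)
}
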
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Fatalf("TestService/EmptyCall(_,_) = _, %v, want _, nil", err) diff --git a/test/goaway_test.go b/test/goaway_test.go index 1bab495f64e6..2a8ff0bfcc04 100644 --- a/test/goaway_test.go +++ b/test/goaway_test.go @@ -595,7 +595,7 @@ func (s) TestGoAwayThenClose(t *testing.T) { client := testgrpc.NewTestServiceClient(cc) t.Log("Waiting for the ClientConn to enter READY state.") - awaitState(ctx, t, cc, connectivity.Ready) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) // We make a streaming RPC and do an one-message-round-trip to make sure // it's created on connection 1. @@ -618,7 +618,7 @@ func (s) TestGoAwayThenClose(t *testing.T) { go s1.GracefulStop() t.Log("Waiting for the ClientConn to enter IDLE state.") - awaitState(ctx, t, cc, connectivity.Idle) + testutils.AwaitState(ctx, t, cc, connectivity.Idle) t.Log("Performing another RPC to create a connection to server 2.") if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { diff --git a/test/healthcheck_test.go b/test/healthcheck_test.go index c5a16d00ae76..2b5b5a82d93c 100644 --- a/test/healthcheck_test.go +++ b/test/healthcheck_test.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/status" @@ -212,33 +213,33 @@ func (s) TestHealthCheckWatchStateChange(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitNotState(ctx, t, cc, connectivity.Idle) - awaitNotState(ctx, t, cc, connectivity.Connecting) - awaitState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitNotState(ctx, t, cc, connectivity.Idle) + testutils.AwaitNotState(ctx, t, cc, connectivity.Connecting) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) if s := cc.GetState(); s != connectivity.TransientFailure { t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s) } ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING) - awaitNotState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitNotState(ctx, t, cc, connectivity.TransientFailure) if s := cc.GetState(); s != connectivity.Ready { t.Fatalf("ClientConn is in %v state, want READY", s) } ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVICE_UNKNOWN) - awaitNotState(ctx, t, cc, connectivity.Ready) + testutils.AwaitNotState(ctx, t, cc, connectivity.Ready) if s := cc.GetState(); s != connectivity.TransientFailure { t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s) } ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING) - awaitNotState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitNotState(ctx, t, cc, connectivity.TransientFailure) if s := cc.GetState(); s != connectivity.Ready { t.Fatalf("ClientConn is in %v state, want READY", s) } ts.SetServingStatus("foo", healthpb.HealthCheckResponse_UNKNOWN) - awaitNotState(ctx, t, cc, connectivity.Ready) + testutils.AwaitNotState(ctx, t, cc, connectivity.Ready) if s := cc.GetState(); s != connectivity.TransientFailure { t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s) } @@ -267,8 +268,8 @@ func (s) TestHealthCheckHealthServerNotRegistered(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitNotState(ctx, t, cc, 
connectivity.Idle) - awaitNotState(ctx, t, cc, connectivity.Connecting) + testutils.AwaitNotState(ctx, t, cc, connectivity.Idle) + testutils.AwaitNotState(ctx, t, cc, connectivity.Connecting) if s := cc.GetState(); s != connectivity.Ready { t.Fatalf("ClientConn is in %v state, want READY", s) } diff --git a/test/pickfirst_test.go b/test/pickfirst_test.go index a7496e854125..fc9e1c48f352 100644 --- a/test/pickfirst_test.go +++ b/test/pickfirst_test.go @@ -55,10 +55,6 @@ const pickFirstServiceConfig = `{"loadBalancingConfig": [{"pick_first":{}}]}` func setupPickFirst(t *testing.T, backendCount int, opts ...grpc.DialOption) (*grpc.ClientConn, *manual.Resolver, []*stubserver.StubServer) { t.Helper() - // Initialize channelz. Used to determine pending RPC count. - czCleanup := channelz.NewChannelzStorageForTesting() - t.Cleanup(func() { czCleanupWrapper(czCleanup, t) }) - r := manual.NewBuilderWithScheme("whatever") backends := make([]*stubserver.StubServer, backendCount) @@ -259,7 +255,7 @@ func (s) TestPickFirst_NewAddressWhileBlocking(t *testing.T) { // Send a resolver update with no addresses. This should push the channel into // TransientFailure. r.UpdateState(resolver.State{}) - awaitState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) doneCh := make(chan struct{}) client := testgrpc.NewTestServiceClient(cc) @@ -355,7 +351,7 @@ func (s) TestPickFirst_StickyTransientFailure(t *testing.T) { } t.Cleanup(func() { cc.Close() }) - awaitState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) // Spawn a goroutine to ensure that the channel stays in TransientFailure. // The call to cc.WaitForStateChange will return false when the main @@ -423,7 +419,7 @@ func (s) TestPickFirst_ShuffleAddressList(t *testing.T) { // Send a resolver update with no addresses. This should push the channel // into TransientFailure. r.UpdateState(resolver.State{}) - awaitState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) // Send the same config as last time with shuffling enabled. Since we are // not connected to backend 0, we should connect to backend 1. @@ -578,10 +574,6 @@ func (s) TestPickFirst_ParseConfig_Failure(t *testing.T) { func setupPickFirstWithListenerWrapper(t *testing.T, backendCount int, opts ...grpc.DialOption) (*grpc.ClientConn, *manual.Resolver, []*stubserver.StubServer, []*testutils.ListenerWrapper) { t.Helper() - // Initialize channelz. Used to determine pending RPC count. 
- czCleanup := channelz.NewChannelzStorageForTesting() - t.Cleanup(func() { czCleanupWrapper(czCleanup, t) }) - backends := make([]*stubserver.StubServer, backendCount) addrs := make([]resolver.Address, backendCount) listeners := make([]*testutils.ListenerWrapper, backendCount) @@ -803,7 +795,7 @@ func (s) TestPickFirst_ResolverError_NoPreviousUpdate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) client := testgrpc.NewTestServiceClient(cc) _, err := client.EmptyCall(ctx, &testpb.Empty{}) @@ -888,7 +880,7 @@ func (s) TestPickFirst_ResolverError_WithPreviousUpdate_Connecting(t *testing.T) addrs := []resolver.Address{{Addr: lis.Addr().String()}} r.UpdateState(resolver.State{Addresses: addrs}) - awaitState(ctx, t, cc, connectivity.Connecting) + testutils.AwaitState(ctx, t, cc, connectivity.Connecting) nrErr := errors.New("error from name resolver") r.ReportError(nrErr) @@ -905,7 +897,7 @@ func (s) TestPickFirst_ResolverError_WithPreviousUpdate_Connecting(t *testing.T) // Closing this channel leads to closing of the connection by our listener. // gRPC should see this as a connection error. close(waitForConnecting) - awaitState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) checkForConnectionError(ctx, t, cc) } @@ -945,7 +937,7 @@ func (s) TestPickFirst_ResolverError_WithPreviousUpdate_TransientFailure(t *test r.UpdateState(resolver.State{Addresses: addrs}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - awaitState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) checkForConnectionError(ctx, t, cc) // An error from the name resolver should result in RPCs failing with that diff --git a/test/roundrobin_test.go b/test/roundrobin_test.go index 15190329aff8..b4b17895b053 100644 --- a/test/roundrobin_test.go +++ b/test/roundrobin_test.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/internal/channelz" imetadata "google.golang.org/grpc/internal/metadata" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" rrutil "google.golang.org/grpc/internal/testutils/roundrobin" "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" @@ -46,10 +47,6 @@ const rrServiceConfig = `{"loadBalancingConfig": [{"round_robin":{}}]}` func testRoundRobinBasic(ctx context.Context, t *testing.T, opts ...grpc.DialOption) (*grpc.ClientConn, *manual.Resolver, []*stubserver.StubServer) { t.Helper() - // Initialize channelz. Used to determine pending RPC count. - czCleanup := channelz.NewChannelzStorageForTesting() - t.Cleanup(func() { czCleanupWrapper(czCleanup, t) }) - r := manual.NewBuilderWithScheme("whatever") const backendCount = 5 @@ -119,7 +116,7 @@ func (s) TestRoundRobin_AddressesRemoved(t *testing.T) { // Send a resolver update with no addresses. This should push the channel into // TransientFailure. r.UpdateState(resolver.State{Addresses: []resolver.Address{}}) - awaitState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) const msgWant = "produced zero addresses" client := testgrpc.NewTestServiceClient(cc) @@ -141,7 +138,7 @@ func (s) TestRoundRobin_NewAddressWhileBlocking(t *testing.T) { // Send a resolver update with no addresses. 
This should push the channel into // TransientFailure. r.UpdateState(resolver.State{Addresses: []resolver.Address{}}) - awaitState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) client := testgrpc.NewTestServiceClient(cc) doneCh := make(chan struct{}) @@ -221,7 +218,7 @@ func (s) TestRoundRobin_AllServersDown(t *testing.T) { b.Stop() } - awaitState(ctx, t, cc, connectivity.TransientFailure) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) // Failfast RPCs should fail with Unavailable. client := testgrpc.NewTestServiceClient(cc) diff --git a/test/subconn_test.go b/test/subconn_test.go index e8c8d936a9fb..cd2ac5a5432d 100644 --- a/test/subconn_test.go +++ b/test/subconn_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" testpb "google.golang.org/grpc/interop/grpc_testing" "google.golang.org/grpc/resolver" ) @@ -115,7 +116,7 @@ func (s) TestSubConnEmpty(t *testing.T) { t.Log("Removing addresses from resolver and SubConn") ss.R.UpdateState(resolver.State{Addresses: []resolver.Address{}}) - awaitState(ctx, t, ss.CC, connectivity.TransientFailure) + testutils.AwaitState(ctx, t, ss.CC, connectivity.TransientFailure) t.Log("Re-adding addresses to resolver and SubConn") ss.R.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address}}}) diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go index 7a358b1fcb5d..11a7d33ed523 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go @@ -388,11 +388,7 @@ func (s) TestSecurityConfigNotFoundInBootstrap(t *testing.T) { t.Fatal(err) } - for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Fatal("Timed out waiting for channel to enter TRANSIENT_FAILURE") - } - } + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) } // A ceritificate provider builder that returns a nil Provider from the starter @@ -456,11 +452,7 @@ func (s) TestCertproviderStoreError(t *testing.T) { t.Fatal(err) } - for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Fatal("Timed out waiting for channel to enter TRANSIENT_FAILURE") - } - } + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) } // Tests the case where the cds LB policy receives security configuration as @@ -545,11 +537,7 @@ func (s) TestSecurityConfigUpdate_BadToGood(t *testing.T) { t.Fatal(err) } - for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Fatal("Timed out waiting for channel to enter TRANSIENT_FAILURE") - } - } + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) // Update the management server with a Cluster resource that contains a // certificate provider instance that is present in the bootstrap diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go index 050449c9643e..1d02c4810912 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go @@ -665,12 +665,7 @@ func 
(s) TestClusterUpdate_Failure(t *testing.T) { t.Fatal("Watch for cluster resource is cancelled when not expected to") } - // Ensure that the ClientConn moves to TransientFailure. - for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Fatalf("Timed out waiting for state change. got %v; want %v", state, connectivity.TransientFailure) - } - } + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) // Ensure that the NACK error is propagated to the RPC caller. const wantClusterNACKErr = "unsupported config_source_specifier" @@ -758,12 +753,8 @@ func (s) TestClusterUpdate_Failure(t *testing.T) { t.Fatal("Watch for cluster resource is cancelled when not expected to") } - // Ensure that the ClientConn moves to TransientFailure. - for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Fatalf("Timed out waiting for state change. got %v; want %v", state, connectivity.TransientFailure) - } - } + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) + // Ensure RPC fails with Unavailable. The actual error message depends on // the picker returned from the priority LB policy, and therefore not // checking for it here. @@ -801,12 +792,7 @@ func (s) TestResolverError(t *testing.T) { resolverErr := errors.New("resolver-error-not-a-resource-not-found-error") r.ReportError(resolverErr) - // Ensure that the ClientConn moves to TransientFailure. - for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Fatalf("Timed out waiting for state change. got %v; want %v", state, connectivity.TransientFailure) - } - } + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) // Drain the resolver error channel. select { @@ -903,12 +889,7 @@ func (s) TestResolverError(t *testing.T) { t.Fatal("Timeout when waiting for resolver error to be pushed to the child policy") } - // Ensure that the ClientConn moves to TransientFailure. - for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Fatalf("Timed out waiting for state change. got %v; want %v", state, connectivity.TransientFailure) - } - } + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) // Ensure RPC fails with Unavailable. The actual error message depends on // the picker returned from the priority LB policy, and therefore not diff --git a/xds/internal/balancer/clusterresolver/e2e_test/balancer_test.go b/xds/internal/balancer/clusterresolver/e2e_test/balancer_test.go index 073775ac9da0..4dbec18229f4 100644 --- a/xds/internal/balancer/clusterresolver/e2e_test/balancer_test.go +++ b/xds/internal/balancer/clusterresolver/e2e_test/balancer_test.go @@ -277,12 +277,7 @@ func (s) TestErrorFromParentLB_ResourceNotFound(t *testing.T) { t.Fatalf("RPCs did not fail after removal of Cluster resource") } - // Ensure that the ClientConn moves to TransientFailure. - for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() { - if !cc.WaitForStateChange(ctx, state) { - t.Fatalf("Timed out waiting for state change. got %v; want %v", state, connectivity.TransientFailure) - } - } + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) // Configure cluster and endpoints resources in the management server. 
resources = e2e.UpdateOptions{ diff --git a/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go b/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go index 4105e3550b7c..5eb8ffd16bd3 100644 --- a/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go +++ b/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go @@ -110,20 +110,14 @@ func (s) TestRingHash_ReconnectToMoveOutOfTransientFailure(t *testing.T) { // which will lead to the channel eventually moving to IDLE. The ring_hash // LB policy is not expected to reconnect by itself at this point. lis.Stop() - for state := cc.GetState(); state != connectivity.Idle && cc.WaitForStateChange(ctx, state); state = cc.GetState() { - } - if err := ctx.Err(); err != nil { - t.Fatalf("Timeout waiting for channel to reach %q after server shutdown: %v", connectivity.Idle, err) - } + + testutils.AwaitState(ctx, t, cc, connectivity.Idle) // Make an RPC to get the ring_hash LB policy to reconnect and thereby move // to TRANSIENT_FAILURE upon connection failure. client.EmptyCall(ctx, &testpb.Empty{}) - for state := cc.GetState(); state != connectivity.TransientFailure && cc.WaitForStateChange(ctx, state); state = cc.GetState() { - } - if err := ctx.Err(); err != nil { - t.Fatalf("Timeout waiting for channel to reach %q after server shutdown: %v", connectivity.TransientFailure, err) - } + + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) // An RPC at this point is expected to fail. if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err == nil {