Skip to content

Commit

Permalink
[PSM Interop] Update logic to detect failed ADS channels (grpc#35280)
Browse files Browse the repository at this point in the history
<!--

If you know who should review your pull request, please assign it to that
person, otherwise the pull request would get assigned randomly.

If your pull request is for a specific language, please add the appropriate
lang label.

-->

Closes grpc#35280

COPYBARA_INTEGRATE_REVIEW=grpc#35280 from yashykt:UpdateInteropScriptForFindingAdsChannel db21338
PiperOrigin-RevId: 591090750
  • Loading branch information
yashykt authored and copybara-github committed Dec 15, 2023
1 parent a5a25c7 commit 1c96d53
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 26 deletions.
1 change: 1 addition & 0 deletions tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

# Type aliases
Message = google.protobuf.message.Message
RpcError = grpc.RpcError


class GrpcClientHelper:
Expand Down
22 changes: 22 additions & 0 deletions tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
# Type aliases
# Channel
Channel = channelz_pb2.Channel
ChannelData = channelz_pb2.ChannelData
ChannelConnectivityState = channelz_pb2.ChannelConnectivityState
ChannelState = ChannelConnectivityState.State # pylint: disable=no-member
_GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest
Expand Down Expand Up @@ -109,6 +110,7 @@ def channel_repr(channel: Channel) -> str:
result += f" target={channel.data.target}"
result += (
f" call_started={channel.data.calls_started}"
+ f" calls_succeeded={channel.data.calls_succeeded}"
+ f" calls_failed={channel.data.calls_failed}"
)
result += f" state={ChannelState.Name(channel.data.state.state)}>"
Expand Down Expand Up @@ -170,6 +172,26 @@ def list_channels(self, **kwargs) -> Iterator[Channel]:
start = max(start, channel.ref.channel_id)
yield channel

def get_channel(self, channel_id, **kwargs) -> Channel:
"""Return a single Channel, otherwise raises RpcError."""
response: channelz_pb2.GetChannelResponse
try:
response = self.call_unary_with_deadline(
rpc="GetChannel",
req=channelz_pb2.GetChannelRequest(channel_id=channel_id),
**kwargs,
)
return response.channel
except grpc.RpcError as err:
if isinstance(err, grpc.Call):
# Translate NOT_FOUND into GrpcApp.NotFound.
if err.code() is grpc.StatusCode.NOT_FOUND:
raise framework.rpc.grpc.GrpcApp.NotFound(
f"Channel with channel_id {channel_id} not found",
)

raise

def list_servers(self, **kwargs) -> Iterator[Server]:
"""Iterate over all pages of all servers that exist in the process."""
start: int = -1
Expand Down
110 changes: 84 additions & 26 deletions tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import datetime
import functools
import logging
import time
from typing import Iterable, List, Optional

import framework.errors
Expand All @@ -36,6 +37,7 @@
)
_ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
_ChannelzChannel = grpc_channelz.Channel
_ChannelzChannelData = grpc_channelz.ChannelData
_ChannelzChannelState = grpc_channelz.ChannelState
_ChannelzSubchannel = grpc_channelz.Subchannel
_ChannelzSocket = grpc_channelz.Socket
Expand Down Expand Up @@ -280,7 +282,7 @@ def wait_for_xds_channel_active(
)

logger.info(
"[%s] ADS: Waiting for successful calls to xDS control plane to %s",
"[%s] ADS: Waiting for active calls to xDS control plane to %s",
self.hostname,
xds_server_uri,
)
Expand All @@ -290,7 +292,7 @@ def wait_for_xds_channel_active(
rpc_deadline=rpc_deadline,
)
logger.info(
"[%s] ADS: Detected successful calls to xDS control plane %s",
"[%s] ADS: Detected active calls to xDS control plane %s",
self.hostname,
xds_server_uri,
)
Expand All @@ -306,29 +308,41 @@ def find_active_xds_channel(
if rpc_deadline is not None:
rpc_params["deadline_sec"] = rpc_deadline.total_seconds()

for channel in self.get_server_channels(xds_server_uri, **rpc_params):
for channel in self.find_channels(xds_server_uri, **rpc_params):
logger.info(
"[%s] xDS control plane channel: %s",
self.hostname,
_ChannelzServiceClient.channel_repr(channel),
)

try:
channel = self.check_channel_successful_calls(
channel_upd = self.check_channel_in_flight_calls(
channel, **rpc_params
)
logger.info(
"[%s] Detected successful calls to xDS control plane %s,"
"[%s] Detected active calls to xDS control plane %s,"
" channel: %s",
self.hostname,
xds_server_uri,
_ChannelzServiceClient.channel_repr(channel),
_ChannelzServiceClient.channel_repr(channel_upd),
)
return channel_upd
except self.NotFound:
# Otherwise, keep searching.
# Continue checking other channels to the same target on
# not found.
continue

return channel
except framework.rpc.grpc.RpcError as err:
# Logged at 'info' and not at 'warning' because this method is
# expected to be called in a retryer. If this error eventually
# causes the retryer to fail, it will be logged fully at 'error'
logger.info(
"[%s] Unexpected error while checking xDS control plane"
" channel %s: %r",
self.hostname,
_ChannelzServiceClient.channel_repr(channel),
err,
)
raise

raise self.ChannelNotActive(
f"[{self.hostname}] Client has no"
Expand All @@ -351,7 +365,7 @@ def find_server_channel_with_state(
expected_state_name: str = _ChannelzChannelState.Name(expected_state)
target: str = self.server_target

for channel in self.get_server_channels(target, **rpc_params):
for channel in self.find_channels(target, **rpc_params):
channel_state: _ChannelzChannelState = channel.data.state.state
logger.info(
"[%s] Server channel: %s",
Expand Down Expand Up @@ -386,10 +400,12 @@ def find_server_channel_with_state(
expected_state=expected_state,
)

def get_server_channels(
self, server_target: str, **kwargs
def find_channels(
self,
target: str,
**rpc_params,
) -> Iterable[_ChannelzChannel]:
return self.channelz.find_channels_for_target(server_target, **kwargs)
return self.channelz.find_channels_for_target(target, **rpc_params)

def find_subchannel_with_state(
self, channel: _ChannelzChannel, state: _ChannelzChannelState, **kwargs
Expand Down Expand Up @@ -419,23 +435,65 @@ def find_subchannels_with_state(
subchannels.append(subchannel)
return subchannels

def check_channel_successful_calls(
self, channel: _ChannelzChannel, **kwargs
) -> _ChannelzChannel:
"""Checks if the channel has any successful calls.
We consider the channel is active if channel is in READY state and calls_started is
greater than calls_failed.
def check_channel_in_flight_calls(
self,
channel: _ChannelzChannel,
*,
wait_between_checks: Optional[_timedelta] = None,
**rpc_params,
) -> Optional[_ChannelzChannel]:
"""Checks if the channel has calls that started, but didn't complete.
We consider the channel is active if channel is in READY state and
calls_started is greater than calls_failed.
This method address race where a call to the xDS control plane server
has just started and a channelz request comes in before the call has
had a chance to fail.
With channels to the xDS control plane, the channel can be READY but the
calls could be failing to initialize, f.e. due to a failure to fetch
OAUTH2 token. To increase the confidence that we have a valid channel
with working OAUTH2 tokens, we check whether the channel is in a READY
state with active calls twice with an interval of 2 seconds between the
two attempts. If the OAUTH2 token is not valid, the call would fail and
be caught in either the first attempt, or the second attempt. It is
possible that between the two attempts, a call fails and a new call is
started, so we also test for equality between the started calls of the
two channelz results.
There still exists a possibility that a call fails on fetching OAUTH2
token after 2 seconds (maybe because there is a slowdown in the
system.) If such a case is observed, consider increasing the interval
from 2 seconds to 5 seconds.
Returns updated channel on success, or None on failure.
"""
if not self.calc_calls_in_flight(channel):
return None

if not wait_between_checks:
wait_between_checks = _timedelta(seconds=2)

# Load the channel second time after the timeout.
time.sleep(wait_between_checks.total_seconds())
channel_upd: _ChannelzChannel = self.channelz.get_channel(
channel.ref.channel_id, **rpc_params
)
if (
channel.data.state.state is _ChannelzChannelState.READY
and channel.data.calls_started > channel.data.calls_failed
not self.calc_calls_in_flight(channel_upd)
or channel.data.calls_started != channel_upd.data.calls_started
):
return channel
return None
return channel_upd

raise self.NotFound(
f"[{self.hostname}] Not found successful calls over the channel."
)
@classmethod
def calc_calls_in_flight(cls, channel: _ChannelzChannel) -> int:
cdata: _ChannelzChannelData = channel.data
if cdata.state.state is not _ChannelzChannelState.READY:
return 0

return cdata.calls_started - cdata.calls_succeeded - cdata.calls_failed

class ChannelNotFound(framework.rpc.grpc.GrpcApp.NotFound):
"""Channel with expected status not found"""
Expand Down

0 comments on commit 1c96d53

Please sign in to comment.