Skip to content

Commit

Permalink
ADS: Fix the issue with channel not refreshing on the second check
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiitk committed Dec 14, 2023
1 parent 1bdf127 commit f163b38
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 51 deletions.
1 change: 1 addition & 0 deletions tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

# Type aliases
Message = google.protobuf.message.Message
RpcError = grpc.RpcError


class GrpcClientHelper:
Expand Down
22 changes: 22 additions & 0 deletions tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
# Type aliases
# Channel
Channel = channelz_pb2.Channel
ChannelData = channelz_pb2.ChannelData
ChannelConnectivityState = channelz_pb2.ChannelConnectivityState
ChannelState = ChannelConnectivityState.State # pylint: disable=no-member
_GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest
Expand Down Expand Up @@ -109,6 +110,7 @@ def channel_repr(channel: Channel) -> str:
result += f" target={channel.data.target}"
result += (
f" call_started={channel.data.calls_started}"
+ f" calls_succeeded={channel.data.calls_succeeded}"
+ f" calls_failed={channel.data.calls_failed}"
)
result += f" state={ChannelState.Name(channel.data.state.state)}>"
Expand Down Expand Up @@ -170,6 +172,26 @@ def list_channels(self, **kwargs) -> Iterator[Channel]:
start = max(start, channel.ref.channel_id)
yield channel

def get_channel(self, channel_id, **kwargs) -> Channel:
"""Return a single Channel, otherwise raises RpcError."""
response: channelz_pb2.GetChannelResponse
try:
response = self.call_unary_with_deadline(
rpc="GetChannel",
req=channelz_pb2.GetChannelRequest(channel_id=channel_id),
**kwargs,
)
return response.channel
except grpc.RpcError as err:
if isinstance(err, grpc.Call):
# Translate NOT_FOUND into GrpcApp.NotFound.
if err.code() is grpc.StatusCode.NOT_FOUND:
raise framework.rpc.grpc.GrpcApp.NotFound(
f"Channel with channel_id {channel_id} not found",
)

raise

def list_servers(self, **kwargs) -> Iterator[Server]:
"""Iterate over all pages of all servers that exist in the process."""
start: int = -1
Expand Down
124 changes: 73 additions & 51 deletions tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
_ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
_ChannelzChannel = grpc_channelz.Channel
_ChannelzChannelData = grpc_channelz.ChannelData
_ChannelzChannelState = grpc_channelz.ChannelState
_ChannelzSubchannel = grpc_channelz.Subchannel
_ChannelzSocket = grpc_channelz.Socket
Expand Down Expand Up @@ -307,58 +308,35 @@ def find_active_xds_channel(
if rpc_deadline is not None:
rpc_params["deadline_sec"] = rpc_deadline.total_seconds()

for channel in self.get_server_channels(xds_server_uri, **rpc_params):
for channel in self.find_channels(xds_server_uri, **rpc_params):
logger.info(
"[%s] xDS control plane channel: %s",
self.hostname,
_ChannelzServiceClient.channel_repr(channel),
)

try:
channel_first_attempt = self.check_channel_successful_calls(
channel_upd = self.check_channel_in_flight_calls(
channel, **rpc_params
)
# Address race where a call to the xDS control plane server has
# just started and a channelz request comes in before the call
# has had a chance to fail.
# With channels to the xDS control plane, the channel can be
# READY but the calls could be failing due to failure to fetch
# OAUTH2 token. To increase the confidence that we have a valid
# channel with working OAUTH2 tokens, we check whether the
# channel is in a READY state with active calls twice with an
# interval of 2 seconds between the two attempts. If the OAUTH2
# token is not valid, the call would fail and be caught in
# either the first attempt, or the second attempt. It is
# possible that between the two attempts, a call fails and a new
# call is started, so we also test for equality between the
# started calls of the two channelz results.
# There still exists a possibility that a call fails on fetching
# OAUTH2 token after 2 seconds (maybe because there is a
# slowdown in the system.) If such a case is observed, consider
# increasing the interval from 2 seconds to 5 seconds.
time.sleep(2)
channel_second_attempt = self.check_channel_successful_calls(
channel, **rpc_params
)
if (
channel_first_attempt.data.calls_started
!= channel_second_attempt.data.calls_started
):
raise self.NotFound(
f"[{self.hostname}] Not found successful calls over the channel."
)
logger.info(
"[%s] Detected successful calls to xDS control plane %s,"
" channel: %s",
self.hostname,
xds_server_uri,
_ChannelzServiceClient.channel_repr(channel),
)
return channel_upd
except self.NotFound:
# Otherwise, keep searching.
# Continue checking other channels to the same target on
# not found.
continue

return channel
except framework.rpc.grpc.RpcError as err:
logger.debug(
f"Unexpected error while checking"
f" channel {channel.ref.channel_id}: {err}"
)
raise

raise self.ChannelNotActive(
f"[{self.hostname}] Client has no"
Expand All @@ -381,7 +359,7 @@ def find_server_channel_with_state(
expected_state_name: str = _ChannelzChannelState.Name(expected_state)
target: str = self.server_target

for channel in self.get_server_channels(target, **rpc_params):
for channel in self.find_channels(target, **rpc_params):
channel_state: _ChannelzChannelState = channel.data.state.state
logger.info(
"[%s] Server channel: %s",
Expand Down Expand Up @@ -416,10 +394,12 @@ def find_server_channel_with_state(
expected_state=expected_state,
)

def get_server_channels(
self, server_target: str, **kwargs
def find_channels(
self,
target: str,
**rpc_params,
) -> Iterable[_ChannelzChannel]:
return self.channelz.find_channels_for_target(server_target, **kwargs)
return self.channelz.find_channels_for_target(target, **rpc_params)

def find_subchannel_with_state(
self, channel: _ChannelzChannel, state: _ChannelzChannelState, **kwargs
Expand Down Expand Up @@ -449,23 +429,65 @@ def find_subchannels_with_state(
subchannels.append(subchannel)
return subchannels

def check_channel_successful_calls(
self, channel: _ChannelzChannel, **kwargs
) -> _ChannelzChannel:
"""Checks if the channel has any successful calls.
We consider the channel is active if channel is in READY state and calls_started is
greater than calls_failed.
def check_channel_in_flight_calls(
self,
channel: _ChannelzChannel,
*,
wait_between_checks: Optional[_timedelta] = None,
**rpc_params,
) -> Optional[_ChannelzChannel]:
"""Checks if the channel has calls that started, but didn't complete.
We consider the channel is active if channel is in READY state and
calls_started is greater than calls_failed.
This method address race where a call to the xDS control plane server
has just started and a channelz request comes in before the call has
had a chance to fail.
With channels to the xDS control plane, the channel can be READY but the
calls could be failing to initialize, f.e. due to a failure to fetch
OAUTH2 token. To increase the confidence that we have a valid channel
with working OAUTH2 tokens, we check whether the channel is in a READY
state with active calls twice with an interval of 2 seconds between the
two attempts. If the OAUTH2 token is not valid, the call would fail and
be caught in either the first attempt, or the second attempt. It is
possible that between the two attempts, a call fails and a new call is
started, so we also test for equality between the started calls of the
two channelz results.
There still exists a possibility that a call fails on fetching OAUTH2
token after 2 seconds (maybe because there is a slowdown in the
system.) If such a case is observed, consider increasing the interval
from 2 seconds to 5 seconds.
Returns updated channel on success, or None on failure.
"""
if not self.calc_calls_in_flight(channel):
return None

if not wait_between_checks:
wait_between_checks = _timedelta(seconds=2)

# Load the channel second time after the timeout.
time.sleep(wait_between_checks.total_seconds())
channel_upd: _ChannelzChannel = self.channelz.get_channel(
channel.ref.channel_id, **rpc_params
)
if (
channel.data.state.state is _ChannelzChannelState.READY
and channel.data.calls_started > channel.data.calls_failed
not self.calc_calls_in_flight(channel_upd)
or channel.data.calls_started != channel_upd.data.calls_started
):
return channel
return None
return channel_upd

raise self.NotFound(
f"[{self.hostname}] Not found successful calls over the channel."
)
@classmethod
def calc_calls_in_flight(cls, channel: _ChannelzChannel) -> int:
cdata: _ChannelzChannelData = channel.data
if cdata.state.state is not _ChannelzChannelState.READY:
return 0

return cdata.calls_started - cdata.calls_succeeded - cdata.calls_failed

class ChannelNotFound(framework.rpc.grpc.GrpcApp.NotFound):
"""Channel with expected status not found"""
Expand Down

0 comments on commit f163b38

Please sign in to comment.