
[serve] Add exponential backoff when retrying replicas #31436

Merged (10 commits) on Jan 27, 2023
Changes from 8 commits
76 changes: 55 additions & 21 deletions python/ray/serve/_private/deployment_state.py
@@ -978,6 +978,10 @@ def __init__(
# DeploymentInfo and bring current deployment to meet new status.
self._target_state: DeploymentTargetState = DeploymentTargetState.default()
self._prev_startup_warning: float = time.time()
# Exponential backoff when retrying a consistently failing deployment
self._last_retry: float = 0.0
self._backoff_time: int = 1
self._max_backoff: int = 64
Contributor: nit: _max_backoff_time_s, _backoff_time_s

Contributor: consider making these parametrizable via env var

Author (@zcin, Jan 24, 2023): Thanks for the suggestions, applied! I made the backoff factor and the max backoff time env variables.
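
As an aside, a minimal sketch of what reading those values from environment variables could look like. The variable names SERVE_BACKOFF_FACTOR and SERVE_MAX_BACKOFF_TIME_S below are hypothetical placeholders, not necessarily the names introduced by this PR:

import os

# Hypothetical environment variable names, used only to sketch the idea; the
# actual names added in this PR may differ.
BACKOFF_FACTOR = float(os.environ.get("SERVE_BACKOFF_FACTOR", "2"))
MAX_BACKOFF_TIME_S = float(os.environ.get("SERVE_MAX_BACKOFF_TIME_S", "64"))


def next_backoff_time_s(current_backoff_time_s: float) -> float:
    """Grow the backoff by the configured factor, capped at the maximum."""
    return min(current_backoff_time_s * BACKOFF_FACTOR, MAX_BACKOFF_TIME_S)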

self._replica_constructor_retry_counter: int = 0
self._replicas: ReplicaStateContainer = ReplicaStateContainer()
self._curr_status_info: DeploymentStatusInfo = DeploymentStatusInfo(
@@ -1104,6 +1108,7 @@ def _set_target_state(self, target_info: DeploymentInfo) -> None:
self._name, DeploymentStatus.UPDATING
)
self._replica_constructor_retry_counter = 0
self._backoff_time = 1

logger.debug(f"Deploying new version of {self._name}: {target_state.version}.")

@@ -1308,28 +1313,43 @@ def _scale_deployment_replicas(self) -> bool:
)
to_add = max(delta_replicas - stopping_replicas, 0)
if to_add > 0:
# Exponential backoff
failed_to_start_threshold = min(
MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT,
self._target_state.num_replicas * 3,
Contributor: What's the reason for choosing self._target_state.num_replicas * 3?

Contributor: Ah is it to compare against the total number of replica restarts across the deployment? (So on average each replica has failed 3 times?)

Author: I used the threshold used for setting the deployment unhealthy (code pointer). Basically, perform exponential backoff after a replica fails 3 times and the deployment is determined unhealthy.

Author (replying to the question above): Yup, I believe so.

)
if self._replica_constructor_retry_counter >= failed_to_start_threshold:
# Wait 1, 2, 4, ... seconds before consecutive retries, with random
# offset added to avoid synchronization
if (
time.time() - self._last_retry
< self._backoff_time + random.uniform(0, 3)
):
return replicas_stopped

self._last_retry = time.time()
logger.info(
f"Adding {to_add} replica{'s' if to_add > 1 else ''} "
f"to deployment '{self._name}'."
)
for _ in range(to_add):
replica_name = ReplicaName(self._name, get_random_letters())
new_deployment_replica = DeploymentReplica(
self._controller_name,
self._detached,
replica_name.replica_tag,
replica_name.deployment_tag,
self._target_state.version,
)
new_deployment_replica.start(
self._target_state.info, self._target_state.version
)

self._replicas.add(ReplicaState.STARTING, new_deployment_replica)
logger.debug(
"Adding STARTING to replica_tag: "
f"{replica_name}, deployment: {self._name}"
)

elif delta_replicas < 0:
replicas_stopped = True
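
Following up on the num_replicas * 3 thread above, a small worked example of the threshold arithmetic. The value assigned to MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT here is an assumed stand-in for illustration, not the real constant:

MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT = 100  # assumed value, for illustration only


def failed_to_start_threshold(target_num_replicas: int) -> int:
    # Back off once the deployment has accumulated roughly three constructor
    # failures per target replica, capped by the global retry limit.
    return min(MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT, target_num_replicas * 3)


# With 2 target replicas, as in test_exponential_backoff below, the threshold
# is 6, matching the test's expected _replica_constructor_retry_counter of 6.
assert failed_to_start_threshold(2) == 6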
@@ -1407,10 +1427,10 @@ def _check_curr_status(self) -> bool:
name=self._name,
status=DeploymentStatus.UNHEALTHY,
message=(
f"The Deployment failed to start {failed_to_start_count} "
"times in a row. This may be due to a problem with the "
"deployment constructor or the initial health check failing. "
"See logs for details."
f"The Deployment failed to start {failed_to_start_count} times "
"in a row. This may be due to a problem with the deployment "
"constructor or the initial health check failing. See logs for "
f"details. Retrying after {self._backoff_time} seconds."
),
)
return False
@@ -1453,6 +1473,7 @@ def _check_startup_replicas(
"""
slow_replicas = []
transitioned_to_running = False
replicas_failed = False
for replica in self._replicas.pop(states=[original_state]):
start_status = replica.check_started()
if start_status == ReplicaStartupStatus.SUCCEEDED:
@@ -1466,6 +1487,7 @@
# Increase startup failure counter if we're tracking it
self._replica_constructor_retry_counter += 1

replicas_failed = True
replica.stop(graceful=False)
self._replicas.add(ReplicaState.STOPPING, replica)
elif start_status in [
Expand All @@ -1485,6 +1507,18 @@ def _check_startup_replicas(
else:
self._replicas.add(original_state, replica)

# If replicas have failed enough times, execute exponential backoff
# Wait 1, 2, 4, ... seconds before consecutive retries
failed_to_start_threshold = min(
MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT,
self._target_state.num_replicas * 3,
)
if (
replicas_failed
and self._replica_constructor_retry_counter > failed_to_start_threshold
):
self._backoff_time = min(2 * self._backoff_time, self._max_backoff)

return slow_replicas, transitioned_to_running

def _check_and_update_replicas(self) -> bool:
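Taken together, the changes above gate replica restarts behind a backoff window that doubles on repeated constructor failures (capped at the maximum) and is stretched by a small random jitter so retries are not synchronized. A standalone, illustrative simulation of that behavior, not the controller code itself:

import random
import time


def should_retry(last_retry_time: float, backoff_time_s: float) -> bool:
    # Only add replicas again once the backoff window, plus a random offset
    # to avoid synchronized retries, has elapsed.
    return time.time() - last_retry_time >= backoff_time_s + random.uniform(0, 3)


# Simulated failure loop: waits grow 1, 2, 4, ... seconds, capped at 64 s.
backoff_time_s, max_backoff_s = 1, 64
last_retry_time = time.time()
for attempt in range(8):
    print(
        f"attempt {attempt + 1}: wait ~{backoff_time_s}s "
        f"(retry allowed now: {should_retry(last_retry_time, backoff_time_s)})"
    )
    backoff_time_s = min(2 * backoff_time_s, max_backoff_s)
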
51 changes: 51 additions & 0 deletions python/ray/serve/tests/test_deployment_state.py
@@ -1991,6 +1991,57 @@ def test_deploy_with_transient_constructor_failure(
assert deployment_state.curr_status_info.status == DeploymentStatus.HEALTHY


@pytest.mark.parametrize("mock_deployment_state", [False], indirect=True)
@patch.object(DriverDeploymentState, "_get_all_node_ids")
def test_exponential_backoff(mock_get_all_node_ids, mock_deployment_state):
"""Test exponential backoff."""
deployment_state, timer = mock_deployment_state
mock_get_all_node_ids.return_value = [(str(i), str(i)) for i in range(2)]

b_info_1, b_version_1 = deployment_info(num_replicas=2)
updating = deployment_state.deploy(b_info_1)
assert updating
assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING

_constructor_failure_loop_two_replica(deployment_state, 3)
assert deployment_state._replica_constructor_retry_counter == 6
last_retry = timer.time()

for i in range(7):
while timer.time() - last_retry < 2**i:
deployment_state.update()
assert deployment_state._replica_constructor_retry_counter == 6 + 2 * i
# Check that during backoff time, no replicas are created
check_counts(deployment_state, total=0)
timer.advance(0.1)  # simulate time passing between each call to update

# Skip past random additional backoff time used to avoid synchronization
timer.advance(5)

# Set new replicas to fail consecutively
check_counts(deployment_state, total=0) # No replicas
deployment_state.update()
last_retry = timer.time()  # This should be the time at which replicas were retried
check_counts(deployment_state, total=2) # Two new replicas
replica_1 = deployment_state._replicas.get()[0]
replica_2 = deployment_state._replicas.get()[1]
replica_1._actor.set_failed_to_start()
replica_2._actor.set_failed_to_start()
timer.advance(0.1)  # simulate time passing between each call to update

# Now the replica should be marked STOPPING after failure.
deployment_state.update()
check_counts(deployment_state, total=2, by_state=[(ReplicaState.STOPPING, 2)])
timer.advance(0.1)  # simulate time passing between each call to update

# Once it's done stopping, replica should be removed.
replica_1._actor.set_done_stopping()
replica_2._actor.set_done_stopping()
deployment_state.update()
check_counts(deployment_state, total=0)
timer.advance(0.1)  # simulate time passing between each call to update


@pytest.fixture
def mock_deployment_state_manager(request) -> Tuple[DeploymentStateManager, Mock]:
ray.init()