Implementation branch #176

Open · wants to merge 19 commits into base: main

Changes from 13 commits
@@ -0,0 +1,54 @@
from typing import Dict, List

from deltacat.utils.ray_utils.retry_handler.batch_scaling_interface import BatchScalingInterface
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject

class AIMDBasedBatchScalingStrategy(BatchScalingInterface):
"""
Default AIMD (additive increase, multiplicative decrease) batch scaling strategy, used when the client does not provide their own batch scaling parameters.
"""
def __init__(self,
task_infos: List[TaskInfoObject],
initial_batch_size: int,
max_batch_size: int,
min_batch_size: int,
additive_increase: int,
multiplicative_decrease: float):
self.task_infos = task_infos
self.batch_index = 0
self.batch_size = initial_batch_size
self.max_batch_size = max_batch_size
self.min_batch_size = min_batch_size
self.additive_increase = additive_increase
self.multiplicative_decrease = multiplicative_decrease

# Tracks completion status per task_id.
self.task_completion_status: Dict[str, bool] = {task.task_id: False for task in self.task_infos}

def has_next_batch(self) -> bool:
"""
Returns True if there are remaining tasks from which to form the next batch.
"""
return self.batch_index < len(self.task_infos)


def next_batch(self) -> List[TaskInfoObject]:
"""
Returns the next batch of tasks, sized according to the current AIMD batch size.
"""
batch_end = min(self.batch_index + self.batch_size, len(self.task_infos))
batch = self.task_infos[self.batch_index:batch_end]
self.batch_index = batch_end
return batch

def mark_task_complete(self, task_info: TaskInfoObject):
self.task_completion_status[task_info.task_id] = True
if (self.batch_size + self.additive_increase) > self.max_batch_size:
self.batch_size = self.max_batch_size
else:
self.batch_size = self.batch_size + self.additive_increase

def mark_task_failed(self, task_info: TaskInfoObject):
self.task_completion_status[task_info.task_id] = False
if (self.batch_size * self.multiplicative_decrease) < self.min_batch_size:
self.batch_size = self.min_batch_size
else:
# Cast to int so the shrunken batch size can still be used for list slicing.
self.batch_size = int(self.batch_size * self.multiplicative_decrease)

27 changes: 27 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/README.md
@@ -0,0 +1,27 @@
This module provides a straggler detection and retry handler framework for Ray tasks.

Within retry_strategy_config.py, the client can provide 3 parameters to start_tasks_execution to perform retries and detect stragglers (see the usage sketch below).
Params:
1. ray_remote_task_info: A list of Ray task objects
2. scaling_strategy: Batch scaling parameters for how many tasks to execute per batch (Optional)
   a. If not provided, a default AIMD (additive increase, multiplicative decrease) strategy will be used for retries
3. straggler_detection: Client-provided class that holds the logic for detecting straggler tasks (Optional)
   a. The client's algorithm must implement the detection interface, which will be used in wait_and_get_results
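
A minimal sketch of wiring these parameters together. The TaskInfoObject constructor arguments (task_id, task_callable, task_input) and the module path of the AIMD strategy are assumptions made for illustration, since neither is spelled out in this PR:

from deltacat.utils.ray_utils.retry_handler.ray_task_submission_handler import RayTaskSubmissionHandler
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
# Module path assumed; the AIMD strategy file's location is not shown in this diff.
from deltacat.utils.ray_utils.retry_handler.aimd_based_batch_scaling_strategy import AIMDBasedBatchScalingStrategy

def square(x: int) -> int:
    return x * x

# TaskInfoObject fields are assumed from how the handler uses them.
task_infos = [TaskInfoObject(task_id=str(i), task_callable=square, task_input=i) for i in range(100)]

# Batches start at 50 tasks, grow by 2 on each success, halve on each failure, bounded to [10, 100].
scaling = AIMDBasedBatchScalingStrategy(task_infos,
                                        initial_batch_size=50,
                                        max_batch_size=100,
                                        min_batch_size=10,
                                        additive_increase=2,
                                        multiplicative_decrease=0.5)

RayTaskSubmissionHandler().start_tasks_execution(ray_remote_task_infos=task_infos,
                                                 scaling_strategy=scaling)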

Use cases:
1. Notifying progress
This is done through ProgressNotifierInterface. The client can implement has_heartbeat and send_heartbeat from the interface
to receive updates on task-level progress. The underlying channel can be an SNS queue or any other indicator the client may choose.
2. Detecting stragglers
Given the straggler detection algorithm implemented against StragglerDetectionInterface, the method is_straggler will report
whether the current task is a straggler according to the client's own logic. To support that decision, the framework provides a
TaskContext containing fields and data the client can use to decide whether a task is a straggler (see the sketch after this list).
3. Retrying retryable exceptions
Within the failures directory, there are common errors that are retryable; when a caught exception is detected as an instance
of the retryable class, the task will be retried. If the client would like
to create their own exceptions to be handled, they can create a class that extends retryable_error or
non_retryable_error, and the framework will handle it based on the configured strategy.
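
A minimal sketch of use cases 2 and 3: a client-provided straggler detector and a custom retryable error. The is_straggler signature follows how the handler calls it; the elapsed_seconds field on TaskContext is an assumption made purely for illustration:

from deltacat.utils.ray_utils.retry_handler.failures.retryable_error import RetryableError

class TimeoutStragglerDetection:
    """Client-supplied detector: a task is a straggler once it runs past a fixed threshold."""
    def __init__(self, threshold_seconds: float):
        self.threshold_seconds = threshold_seconds

    def is_straggler(self, task_info, task_context) -> bool:
        # elapsed_seconds is an assumed TaskContext field used only for this sketch.
        return task_context.elapsed_seconds > self.threshold_seconds

class ThrottlingError(RetryableError):
    """Client-defined exception that the framework should treat as retryable."""
    def __init__(self, *args: object) -> None:
        super().__init__(*args)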




31 changes: 31 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/batch_scaling_interface.py
@@ -0,0 +1,31 @@
from typing import List, Any, Protocol
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
class BatchScalingInterface(Protocol):
"""
Interface for a generic batch scaling that the client can provide.
"""
def has_next_batch(self) -> bool:
"""
Returns True if there are remaining tasks from which a new batch can be formed.
"""
pass
def next_batch(self) -> List[TaskInfoObject]:
"""
Returns the next batch of tasks to execute.
"""
pass
def mark_task_complete(self, task_info: TaskInfoObject) -> None:
"""
Marks the given task as completed so the strategy can track which tasks are done and which still need to run.
"""
pass

def mark_task_failed(self, task_info: TaskInfoObject) -> None:
"""
If the task returned a caught exception, mark it as failed so the strategy can apply
its decrease (for AIMD, a multiplicative decrease) to the batch size.
:param task_info: the task that failed
:return: None
"""
pass
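
A minimal sketch of an alternative client-provided implementation of this protocol that keeps a fixed batch size; it is illustrative only and not part of this PR:

from typing import List
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject

class FixedBatchScalingStrategy:
    """Always hands out batches of the same size, regardless of successes or failures."""
    def __init__(self, task_infos: List[TaskInfoObject], batch_size: int):
        self.task_infos = task_infos
        self.batch_size = batch_size
        self.batch_index = 0

    def has_next_batch(self) -> bool:
        return self.batch_index < len(self.task_infos)

    def next_batch(self) -> List[TaskInfoObject]:
        batch_end = min(self.batch_index + self.batch_size, len(self.task_infos))
        batch = self.task_infos[self.batch_index:batch_end]
        self.batch_index = batch_end
        return batch

    def mark_task_complete(self, task_info: TaskInfoObject) -> None:
        pass  # Fixed size: nothing to adjust on success.

    def mark_task_failed(self, task_info: TaskInfoObject) -> None:
        pass  # Fixed size: nothing to adjust on failure.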
16 changes: 16 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/exception_util.py
@@ -0,0 +1,16 @@
from typing import List, Optional
from ray_manager.models.ray_remote_task_exception_retry_strategy_config import RayRemoteTaskExceptionRetryConfig
def get_retry_strategy_config_for_known_exception(exception: Exception,
exception_retry_strategy_configs: List[RayRemoteTaskExceptionRetryConfig]) -> Optional[RayRemoteTaskExceptionRetryConfig]:
"""
Checks whether the exception seen is recognized as a retryable error or not
"""
# First pass: prefer an exact type match against the configured exceptions.
for exception_retry_strategy_config in exception_retry_strategy_configs:
if type(exception) is type(exception_retry_strategy_config.exception):
return exception_retry_strategy_config

# Second pass: fall back to matching subclasses of the configured exceptions.
for exception_retry_strategy_config in exception_retry_strategy_configs:
if isinstance(exception, type(exception_retry_strategy_config.exception)):
return exception_retry_strategy_config

return None
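
A minimal usage sketch, assuming the names above are in scope and that RayRemoteTaskExceptionRetryConfig accepts an exception instance plus a max_retry_attempts value (a field referenced by the task submission handler); the exact constructor shown here is an assumption:

# Hypothetical configuration: treat ConnectionError as retryable up to 3 times.
configs = [RayRemoteTaskExceptionRetryConfig(exception=ConnectionError(), max_retry_attempts=3)]

config = get_retry_strategy_config_for_known_exception(ConnectionError("socket closed"), configs)
if config is not None:
    print(f"Retryable, up to {config.max_retry_attempts} attempts")
else:
    print("Not retryable, should be surfaced to the caller")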
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.failures.retryable_error import RetryableError

class AWSSecurityTokenRateExceededException(RetryableError):

def __init__(self, *args: object) -> None:
super().__init__(*args)
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.failures.retryable_error import RetryableError

class CairnsClientException(RetryableError):

def __init__(self, *args: object) -> None:
super().__init__(*args)
@@ -0,0 +1,7 @@
class NonRetryableError(Exception):
"""
Class represents a non-retryable error
"""
def __init__(self, *args: object):
super().__init__(*args)
@@ -0,0 +1,7 @@
class RetryableError(Exception):
"""
Class for errors that can be retried
"""
def __init__(self, *args: object) -> None:
super().__init__(*args)
@@ -0,0 +1,9 @@
from deltacat.utils.ray_utils.retry_handler.failures.non_retryable_error import NonRetryableError

class UnexpectedRayTaskError(NonRetryableError):
"""
An error class denoting that an operation could not be completed because of an unexpected Ray task error
"""

def __init__(self, *args: object) -> None:
super().__init__(*args)
@@ -0,0 +1,17 @@
from typing import List, Protocol
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
class ProgressNotifierInterface(Protocol):
"""
Interface for client injected progress notification system.
"""
def has_heartbeat(self, task_info: TaskInfoObject) -> bool:
"""
Tells the parent task whether the current task has a heartbeat
"""
pass
def send_heartbeat(self, parent_task_info: TaskInfoObject) -> bool:
"""
Sends progress of the current task to its parent task
"""
pass
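
A minimal sketch of a client-provided notifier that tracks heartbeats in memory; an SNS- or SQS-backed notifier would follow the same shape. The task_id field on TaskInfoObject is assumed from its use elsewhere in this PR:

import time
from typing import Dict

from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject

class InMemoryProgressNotifier:
    """Records the last heartbeat time per task and reports liveness within a timeout window."""
    def __init__(self, timeout_seconds: float = 60.0):
        self.timeout_seconds = timeout_seconds
        self.last_heartbeat: Dict[str, float] = {}

    def send_heartbeat(self, parent_task_info: TaskInfoObject) -> bool:
        # task_id is assumed to be a field of TaskInfoObject, as used elsewhere in this PR.
        self.last_heartbeat[parent_task_info.task_id] = time.time()
        return True

    def has_heartbeat(self, task_info: TaskInfoObject) -> bool:
        last = self.last_heartbeat.get(task_info.task_id)
        return last is not None and (time.time() - last) < self.timeout_seconds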

175 changes: 175 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/ray_task_submission_handler.py
@@ -0,0 +1,175 @@
from __future__ import annotations
from typing import Any, Dict, List, cast, Optional
from deltacat.utils.ray_utils.retry_handler.ray_remote_tasks_batch_scaling_strategy import RayRemoteTasksBatchScalingStrategy
import ray
import time
import logging
from deltacat.logs import configure_logger
from deltacat.utils.ray_utils.retry_handler.task_execution_error import RayRemoteTaskExecutionError
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
from deltacat.utils.ray_utils.retry_handler.retry_strategy_config import get_retry_strategy_config_for_known_exception
# Module path assumed; UnexpectedRayTaskError is defined under the failures/ directory in this PR.
from deltacat.utils.ray_utils.retry_handler.failures.unexpected_ray_task_error import UnexpectedRayTaskError

logger = configure_logger(logging.getLogger(__name__))

@ray.remote
def submit_single_task(taskObj: TaskInfoObject, task_context: Optional[Any] = None) -> Any:
"""
Submits a single task for execution, handles any exceptions that may occur during execution,
and applies appropriate retry strategies if they are defined.
"""
try:
taskObj.attempt_count += 1
curr_attempt = taskObj.attempt_count
logger.debug(f"Executing the submitted Ray remote task as part of attempt number: {current_attempt_number}")
return taskObj.task_callable(taskObj.task_input)
except (Exception) as exception:
exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(exception, taskObj.exception_retry_strategy_configs)
if exception_retry_strategy_config is not None:
return RayRemoteTaskExecutionError(exception_retry_strategy_config.exception, taskObj)

logger.error(f"The exception thrown by submitted Ray task during attempt number: {current_attempt_number} is non-retryable or unexpected, hence throwing Non retryable exception: {exception}")
raise UnexpectedRayTaskError(str(exception))

class RayTaskSubmissionHandler:
"""
Starts execution of a given list of Ray tasks, with optional batch scaling strategy and straggler detection
"""
def start_tasks_execution(self,
ray_remote_task_infos: List[TaskInfoObject],
scaling_strategy: Optional[BatchScalingStrategy] = None,
straggler_detection: Optional[StragglerDetectionInterface] = None,
retry_strategy: Optional[RetryTaskInterface] = None,
task_context: Optional[TaskContext] = None) -> None:
"""
Prepares and initiates the execution of a batch of tasks and can optionally support
custom client batch scaling, straggler detection, and task context
"""
if scaling_strategy is None:
scaling_strategy = AIMDBasedBatchScalingStrategy(ray_remote_task_infos, 50, 100, 10, 2, 0.5)
if retry_strategy is None:
retry_strategy = RetryTaskDefault(max_retries = 3)

active_tasks = []
attempts = {task.task_id: 0 for task in ray_remote_task_infos}

while scaling_strategy.has_next_batch():
current_batch = scaling_strategy.next_batch()
for task in current_batch:
self._submit_task(task)
active_tasks.append(task) #maybe should be task_id

while active_tasks:
completed_task = self._get_task_results(1)
if isinstance(completed_task, RayRemoteTaskExecutionError):
scaling_strategy.mark_task_failed(completed_task)
if retry_strategy.should_retry(task, completed_task): #add max retry here
attempts[task.task_id] += 1
self.ray_remote_task_infos.append(completed_task)
active_tasks.remove(completed_task)
else:
scaling_strategy.mark_task_complete(completed_task)
active_tasks.remove(completed_task)
# If straggler detection is enabled, iterate over the active tasks again
if straggler_detection is not None:
for task in active_tasks[:]:
if straggler_detection.is_straggler(task, task_context):
ray.cancel(task)
active_tasks.remove(task)
# If you want to re-add the cancelled stragglers to the task queue
self.ray_remote_task_infos.append(task)

def _wait_and_get_all_task_results(self) -> List[Any]:
return self._get_task_results(self.num_of_submitted_tasks)

def _get_task_results(self, num_of_results: int) -> List[Any]:
"""
Gets results from a list of tasks to be executed, and catches exceptions to manage the retry strategy.
Optional: Given a StragglerDetectionInterface, can detect and handle straggler tasks according to the client logic
"""
if not self.unfinished_promises or num_of_results == 0:
return []
elif num_of_results > len(self.unfinished_promises):
num_of_results = len(self.unfinished_promises)

finished, self.unfinished_promises = ray.wait(self.unfinished_promises, num_returns=num_of_results)
successful_results = []

for finished_promise in finished:
    finished_result = None
    try:
        finished_result = ray.get(finished_promise)
    except Exception as exception:
        # Classify the exception and wrap it in an execution error so it can be retried below.
        finished_result = self._handle_ray_exception(exception=exception, ray_remote_task_info=self.task_promise_obj_ref_to_task_info_map[str(finished_promise)])

    if isinstance(finished_result, RayRemoteTaskExecutionError):
        exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(finished_result.exception, finished_result.ray_remote_task_info.exception_retry_strategy_configs)
        if (exception_retry_strategy_config is None or finished_result.ray_remote_task_info.num_of_attempts > exception_retry_strategy_config.max_retry_attempts):
            logger.error(f"The submitted task has exhausted all the maximum retries configured and finally throws exception - {finished_result.exception}")
            raise finished_result.exception
        self._update_ray_remote_task_options_on_exception(finished_result.exception, finished_result.ray_remote_task_info)
        self.unfinished_promises.append(self._invoke_ray_remote_task(ray_remote_task_info=finished_result.ray_remote_task_info))
    else:
        successful_results.append(finished_result)
        del self.task_promise_obj_ref_to_task_info_map[str(finished_promise)]

num_of_successful_results = len(successful_results)
self.num_of_submitted_tasks_completed += num_of_successful_results
self.current_batch_size -= num_of_successful_results

self._enqueue_new_tasks(num_of_successful_results)

if num_of_successful_results < num_of_results:
successful_results.extend(self._get_task_results(num_of_results - num_of_successful_results))
return successful_results
else:
return successful_results


def _enqueue_new_tasks(self, num_of_tasks: int) -> None:
"""
Helper method to submit a specified number of tasks
"""
new_tasks_submitted = self.remaining_ray_remote_task_infos[:num_of_tasks]
num_of_new_tasks_submitted = len(new_tasks_submitted)
self._submit_tasks(new_tasks_submitted)
self.remaining_ray_remote_task_infos = self.remaining_ray_remote_task_infos[num_of_tasks:]
self.current_batch_size += num_of_new_tasks_submitted
logger.info(f"Enqueued {num_of_new_tasks_submitted} new tasks. Current concurrency of tasks execution: {self.current_batch_size}, Current Task progress: {self.num_of_submitted_tasks_completed}/{self.num_of_submitted_tasks}")

def _submit_tasks(self, info_objs: List[TaskInfoObject]) -> None:
for info_obj in info_objs:
time.sleep(0.005)
self.unfinished_promises.append(self._invoke_ray_remote_task(info_obj))

#replace with ray.options
def _invoke_ray_remote_task(self, ray_remote_task_info: RayRemoteTaskInfo) -> Any:
#change to using ray.options
ray_remote_task_options_arguments = dict()

if ray_remote_task_info.ray_remote_task_options.memory:
ray_remote_task_options_arguments['memory'] = ray_remote_task_info.ray_remote_task_options.memory

if ray_remote_task_info.ray_remote_task_options.num_cpus:
ray_remote_task_options_arguments['num_cpus'] = ray_remote_task_info.ray_remote_task_options.num_cpus

if ray_remote_task_info.ray_remote_task_options.placement_group:
ray_remote_task_options_arguments['placement_group'] = ray_remote_task_info.ray_remote_task_options.placement_group

ray_remote_task_promise_obj_ref = submit_single_task.options(**ray_remote_task_options_arguments).remote(ray_remote_task_info)
self.task_promise_obj_ref_to_task_info_map[str(ray_remote_task_promise_obj_ref)] = ray_remote_task_info

return ray_remote_task_promise_obj_ref

#replace with ray.options
def _update_ray_remote_task_options_on_exception(self, exception: Exception, ray_remote_task_info: RayRemoteTaskInfo):
exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(exception, ray_remote_task_info.exception_retry_strategy_configs)
if exception_retry_strategy_config and ray_remote_task_info.ray_remote_task_options.memory:
logger.info(f"Updating the Ray remote task options after encountering exception: {exception}")
ray_remote_task_memory_multiply_factor = exception_retry_strategy_config.ray_remote_task_memory_multiply_factor
ray_remote_task_info.ray_remote_task_options.memory *= ray_remote_task_memory_multiply_factor
logger.info(f"Updated ray remote task options Memory: {ray_remote_task_info.ray_remote_task_options.memory}")
#replace with own exceptions
def _handle_ray_exception(self, exception: Exception, ray_remote_task_info: RayRemoteTaskInfo) -> RayRemoteTaskExecutionError:
logger.error(f"Ray remote task failed with {type(exception)} Ray exception: {exception}")
if type(exception).__name__ == "AWSSecurityTokenRateExceededException":