diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index e3a333b..2101823 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -2,8 +2,8 @@ name: test
 on:
   push:
     branches: [ "main" ]
-  pull_request:
-    branches: [ "main" ]
+  pull_request_target:
+    branches: ["main"]

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
diff --git a/docs/index.rst b/docs/index.rst
index 270fd7e..d626fc4 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,4 +1,4 @@
-Welcome to Ray provider documentation!
+Welcome to the Ray provider documentation!
 ======================================

 .. toctree::
@@ -12,14 +12,14 @@ Welcome to Ray provider documentation!
    API Reference
    Contributing

-This repository provides tools for integrating `Apache Airflow®`_ with Ray, enabling the orchestration of Ray jobs within Airflow DAGs. It includes a decorator, two operators, and one trigger designed to efficiently manage and monitor Ray jobs and services.
+This repository contains modules for integrating `Apache Airflow®`_ with Ray, enabling the orchestration of Ray jobs from Airflow DAGs. It includes a decorator, two operators, and one trigger designed to efficiently manage and monitor Ray jobs and services.

 Benefits of using this provider include:

 - **Integration**: Incorporate Ray jobs into Airflow DAGs for unified workflow management.
 - **Distributed computing**: Use Ray's distributed capabilities within Airflow pipelines for scalable ETL, LLM fine-tuning etc.
 - **Monitoring**: Track Ray job progress through Airflow's user interface.
-- **Dependency management**: Define and manage dependencies between Ray jobs and other tasks in DAGs.
+- **Dependency management**: Define and manage dependencies between Ray jobs and other tasks in Airflow DAGs.
 - **Resource allocation**: Run Ray jobs alongside other task types within a single pipeline.

 .. _Apache Airflow®: https://airflow.apache.org/
@@ -39,15 +39,15 @@ Quickstart

 See the :doc:`Getting Started ` page for detailed instructions on how to begin using the provider.

-What is the Ray provider?
+Why use Airflow with Ray?
 -------------------------

-Enterprise data value extraction involves two crucial components:
+Value creation from data in an enterprise environment involves two crucial components:

-- Data Engineering
-- Data Science/ML/AI
+- Data Engineering (ETL/ELT/Infrastructure Management)
+- Data Science (ML/AI)

-While Airflow excels at data engineering tasks through its extensive plugin ecosystem, it generally relies on external systems when dealing with large-scale ETL(100s GB to PB scale) or AI tasks such as fine-tuning & deploying LLMs etc.
+While Airflow excels at orchestrating both data engineering and data science tasks through its extensive provider ecosystem, it often relies on external systems when dealing with large-scale (100s of GB to PB scale) data and compute (GPU) requirements, such as fine-tuning and deploying LLMs.

 Ray is a particularly powerful platform for handling large scale computations and this provider makes it very straightforward to orchestrate Ray jobs from Airflow.

@@ -57,12 +57,11 @@ Ray is a particularly powerful platform for handling large scale computations an
    :width: 499
    :height: 561

-The architecture diagram above shows how we can deploy both Airflow & Ray on a Kubernetes cluster for elastic compute.
-
+The architecture diagram above shows that we can run both Airflow and Ray side by side on Kubernetes to leverage the best of both worlds. Airflow can be used to orchestrate Ray jobs and services, while Ray can be used to run distributed computations.

 Use Cases
 ^^^^^^^^^
-- **Scalable ETL**: Orchestrate and monitor Ray jobs on on-demand compute clusters using the Ray Data library. These operations could be custom Python code or ML model inference.
+- **Scalable ETL**: Orchestrate and monitor Ray jobs to perform distributed ETL for heavy data loads on on-demand compute clusters using the Ray Data library.
 - **Model Training**: Schedule model training or fine-tuning jobs on flexible cadences (daily/weekly/monthly). Benefits include:

   * Optimize resource utilization by scheduling Ray jobs during cost-effective periods
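The index page above describes the decorator, the two operators, and the trigger without showing usage. The sketch below illustrates the decorator-based flow from the Use Cases section under stated assumptions: the `@ray.task` name is taken from the trigger docstring in this patch, while the import path, connection id, cluster YAML path, and Ray Data snippet are placeholders, and the config keys simply mirror the `SubmitRayJob` parameters exercised by the tests further down.

```python
# Hedged sketch only: the import path, connection id, YAML path, and config
# values below are illustrative assumptions, not verbatim provider documentation.
from datetime import datetime

from airflow import DAG
from ray_provider.decorators.ray import ray  # assumed location of the @ray.task decorator

RAY_CONFIG = {
    "conn_id": "ray_conn",                   # placeholder Airflow connection to the Ray/K8s environment
    "runtime_env": {"pip": ["pandas"]},      # dependencies shipped with the Ray job
    "num_cpus": 2,
    "ray_cluster_yaml": "ray-cluster.yaml",  # placeholder RayCluster spec for on-demand compute
    "job_timeout_seconds": 600,              # forwarded to SubmitRayJob (see the decorator change below)
}

with DAG(dag_id="ray_scalable_etl_example", start_date=datetime(2024, 1, 1), schedule=None, catchup=False):

    @ray.task(config=RAY_CONFIG)
    def transform_rows() -> None:
        # Runs remotely as a Ray job; uses Ray Data for distributed ETL.
        import ray as ray_lib

        dataset = ray_lib.data.range(1_000_000)
        print(f"processed {dataset.count()} rows")

    transform_rows()
```

Pointing the task body at a training routine instead of an ETL step gives the scheduled model-training flavour of the same pattern.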
""" - self._setup_cluster(context=context) - - self.dashboard_url = self._get_dashboard_url(context) - - self.job_id = self.hook.submit_ray_job( - dashboard_url=self.dashboard_url, - entrypoint=self.entrypoint, - runtime_env=self.runtime_env, - entrypoint_num_cpus=self.num_cpus, - entrypoint_num_gpus=self.num_gpus, - entrypoint_memory=self.memory, - entrypoint_resources=self.ray_resources, - ) - self.log.info(f"Ray job submitted with id: {self.job_id}") - - if self.wait_for_completion: - current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) - self.log.info(f"Current job status for {self.job_id} is: {current_status}") - - if current_status not in self.terminal_states: - self.log.info("Deferring the polling to RayJobTrigger...") - self.defer( - trigger=RayJobTrigger( - job_id=self.job_id, - conn_id=self.conn_id, - xcom_dashboard_url=self.dashboard_url, - poll_interval=self.poll_interval, - fetch_logs=self.fetch_logs, - ), - method_name="execute_complete", - timeout=timedelta(seconds=self.job_timeout_seconds), - ) - elif current_status == JobStatus.SUCCEEDED: - self.log.info("Job %s completed successfully", self.job_id) - elif current_status == JobStatus.FAILED: - raise AirflowException(f"Job failed:\n{self.job_id}") - elif current_status == JobStatus.STOPPED: - raise AirflowException(f"Job was cancelled:\n{self.job_id}") - else: - raise AirflowException(f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`") - - return self.job_id + try: + self._setup_cluster(context=context) + + self.dashboard_url = self._get_dashboard_url(context) + + self.job_id = self.hook.submit_ray_job( + dashboard_url=self.dashboard_url, + entrypoint=self.entrypoint, + runtime_env=self.runtime_env, + entrypoint_num_cpus=self.num_cpus, + entrypoint_num_gpus=self.num_gpus, + entrypoint_memory=self.memory, + entrypoint_resources=self.ray_resources, + ) + self.log.info(f"Ray job submitted with id: {self.job_id}") + + if self.wait_for_completion: + current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) + self.log.info(f"Current job status for {self.job_id} is: {current_status}") + + if current_status not in self.terminal_states: + self.log.info("Deferring the polling to RayJobTrigger...") + self.defer( + trigger=RayJobTrigger( + job_id=self.job_id, + conn_id=self.conn_id, + xcom_dashboard_url=self.dashboard_url, + ray_cluster_yaml=self.ray_cluster_yaml, + gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, + poll_interval=self.poll_interval, + fetch_logs=self.fetch_logs, + ), + method_name="execute_complete", + timeout=self.job_timeout_seconds, + ) + elif current_status == JobStatus.SUCCEEDED: + self.log.info("Job %s completed successfully", self.job_id) + elif current_status == JobStatus.FAILED: + raise AirflowException(f"Job failed:\n{self.job_id}") + elif current_status == JobStatus.STOPPED: + raise AirflowException(f"Job was cancelled:\n{self.job_id}") + else: + raise AirflowException( + f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`" + ) + return self.job_id + except Exception as e: + self._delete_cluster() + raise AirflowException(f"SubmitRayJob operator failed due to {e}. 
Cleaning up resources...") def execute_complete(self, context: Context, event: dict[str, Any]) -> None: """ diff --git a/ray_provider/triggers/ray.py b/ray_provider/triggers/ray.py index 8a14457..745c74f 100644 --- a/ray_provider/triggers/ray.py +++ b/ray_provider/triggers/ray.py @@ -30,6 +30,8 @@ def __init__( job_id: str, conn_id: str, xcom_dashboard_url: str | None, + ray_cluster_yaml: str | None, + gpu_device_plugin_yaml: str, poll_interval: int = 30, fetch_logs: bool = True, ): @@ -37,6 +39,8 @@ def __init__( self.job_id = job_id self.conn_id = conn_id self.dashboard_url = xcom_dashboard_url + self.ray_cluster_yaml = ray_cluster_yaml + self.gpu_device_plugin_yaml = gpu_device_plugin_yaml self.fetch_logs = fetch_logs self.poll_interval = poll_interval @@ -52,6 +56,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "job_id": self.job_id, "conn_id": self.conn_id, "xcom_dashboard_url": self.dashboard_url, + "ray_cluster_yaml": self.ray_cluster_yaml, + "gpu_device_plugin_yaml": self.gpu_device_plugin_yaml, "fetch_logs": self.fetch_logs, "poll_interval": self.poll_interval, }, @@ -66,6 +72,28 @@ def hook(self) -> RayHook: """ return RayHook(conn_id=self.conn_id) + async def cleanup(self) -> None: + """ + Cleanup method to ensure resources are properly deleted. This will be called when the trigger encounters an exception. + + Example scenario: A job is submitted using the @ray.task decorator with a Ray specification. After the cluster is started + and the job is submitted, the trigger begins tracking its progress. However, if the job is stopped through the UI at this stage, the cluster + resources are not deleted. + + """ + try: + if self.ray_cluster_yaml: + self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}") + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml + ) + self.log.info("Ray cluster deletion process completed") + else: + self.log.info("No Ray cluster YAML provided, skipping cluster deletion") + except Exception as e: + self.log.error(f"Unexpected error during cleanup: {str(e)}") + async def _poll_status(self) -> None: while not self._is_terminal_state(): await asyncio.sleep(self.poll_interval) @@ -109,6 +137,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: } ) except Exception as e: + self.log.error(f"Error occurred: {str(e)}") + await self.cleanup() yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id}) def _is_terminal_state(self) -> bool: diff --git a/tests/decorators/test_ray_decorators.py b/tests/decorators/test_ray_decorators.py index b4833b0..109f748 100644 --- a/tests/decorators/test_ray_decorators.py +++ b/tests/decorators/test_ray_decorators.py @@ -1,3 +1,4 @@ +from datetime import timedelta from unittest.mock import MagicMock, patch import pytest @@ -38,7 +39,7 @@ def dummy_callable(): assert operator.ray_resources == {"custom_resource": 1} assert operator.fetch_logs == True assert operator.wait_for_completion == True - assert operator.job_timeout_seconds == 300 + assert operator.job_timeout_seconds == timedelta(seconds=300) assert operator.poll_interval == 30 assert operator.xcom_task_key == "ray_result" @@ -59,7 +60,7 @@ def dummy_callable(): assert operator.ray_resources is None assert operator.fetch_logs == True assert operator.wait_for_completion == True - assert operator.job_timeout_seconds == 600 + assert operator.job_timeout_seconds == timedelta(seconds=600) assert 
operator.poll_interval == 60 assert operator.xcom_task_key is None diff --git a/tests/operators/test_ray_operators.py b/tests/operators/test_ray_operators.py index a22e62e..11d3d0c 100644 --- a/tests/operators/test_ray_operators.py +++ b/tests/operators/test_ray_operators.py @@ -1,10 +1,12 @@ +from datetime import timedelta from unittest.mock import MagicMock, Mock, patch import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from ray.job_submission import JobStatus from ray_provider.operators.ray import DeleteRayCluster, SetupRayCluster, SubmitRayJob +from ray_provider.triggers.ray import RayJobTrigger class TestSetupRayCluster: @@ -160,10 +162,32 @@ def test_init(self): assert operator.gpu_device_plugin_yaml == "https://example.com/plugin.yml" assert operator.fetch_logs == True assert operator.wait_for_completion == True - assert operator.job_timeout_seconds == 1200 + assert operator.job_timeout_seconds == timedelta(seconds=1200) assert operator.poll_interval == 30 assert operator.xcom_task_key == "task.key" + def test_init_no_timeout(self): + operator = SubmitRayJob( + task_id="test_task", + conn_id="test_conn", + entrypoint="python script.py", + runtime_env={"pip": ["package1", "package2"]}, + num_cpus=2, + num_gpus=1, + memory=1000, + resources={"custom_resource": 1}, + ray_cluster_yaml="cluster.yaml", + kuberay_version="1.0.0", + update_if_exists=True, + gpu_device_plugin_yaml="https://example.com/plugin.yml", + fetch_logs=True, + wait_for_completion=True, + job_timeout_seconds=0, + poll_interval=30, + xcom_task_key="task.key", + ) + assert operator.job_timeout_seconds is None + def test_on_kill(self, mock_hook): operator = SubmitRayJob(task_id="test_task", conn_id="test_conn", entrypoint="python script.py", runtime_env={}) operator.job_id = "test_job_id" @@ -388,3 +412,182 @@ def test_delete_cluster_exception(self, mock_ray_hook): assert str(exc_info.value) == "Cluster deletion failed" mock_hook.delete_ray_cluster.assert_called_once() + + @pytest.mark.parametrize( + "xcom_task_key, expected_task, expected_key", + [ + ("task.key", "task", "key"), + ("single_key", None, "single_key"), + ], + ) + def test_get_dashboard_url_xcom_variants(self, operator, context, xcom_task_key, expected_task, expected_key): + operator.xcom_task_key = xcom_task_key + context["ti"].xcom_pull.return_value = "http://dashboard.url" + + result = operator._get_dashboard_url(context) + + assert result == "http://dashboard.url" + if expected_task: + context["ti"].xcom_pull.assert_called_once_with(task_ids=expected_task, key=expected_key) + else: + context["ti"].xcom_pull.assert_called_once_with(task_ids=context["task"].task_id, key=expected_key) + + def test_execute_job_unexpected_state(self, mock_hook, context): + operator = SubmitRayJob( + task_id="test_task", + conn_id="test_conn", + entrypoint="python script.py", + runtime_env={}, + wait_for_completion=True, + ) + mock_hook.submit_ray_job.return_value = "test_job_id" + mock_hook.get_ray_job_status.return_value = "UNEXPECTED_STATE" + + with patch.object(operator, "_setup_cluster"), pytest.raises(TaskDeferred) as exc_info: + operator.execute(context) + + assert isinstance(exc_info.value.trigger, RayJobTrigger) + + @pytest.mark.parametrize("dashboard_url", [None, "http://dashboard.url"]) + def test_execute_defer(self, mock_hook, context, dashboard_url): + operator = SubmitRayJob( + task_id="test_task", + conn_id="test_conn", + entrypoint="python script.py", + runtime_env={}, + 
wait_for_completion=True, + ray_cluster_yaml="cluster.yaml", + gpu_device_plugin_yaml="gpu_plugin.yaml", + poll_interval=30, + fetch_logs=True, + job_timeout_seconds=600, + ) + mock_hook.submit_ray_job.return_value = "test_job_id" + mock_hook.get_ray_job_status.return_value = JobStatus.PENDING + + with patch.object(operator, "_setup_cluster"), patch.object( + operator, "_get_dashboard_url", return_value=dashboard_url + ), pytest.raises(TaskDeferred) as exc_info: + operator.execute(context) + + trigger = exc_info.value.trigger + assert isinstance(trigger, RayJobTrigger) + assert trigger.job_id == "test_job_id" + assert trigger.conn_id == "test_conn" + assert trigger.dashboard_url == dashboard_url + assert trigger.ray_cluster_yaml == "cluster.yaml" + assert trigger.gpu_device_plugin_yaml == "gpu_plugin.yaml" + assert trigger.poll_interval == 30 + assert trigger.fetch_logs is True + + def test_execute_complete_unexpected_status(self, operator): + event = {"status": "UNEXPECTED", "message": "Unexpected status"} + with patch.object(operator, "_delete_cluster"), pytest.raises(AirflowException) as exc_info: + operator.execute_complete({}, event) + + assert "Unexpected event status" in str(exc_info.value) + + def test_execute_complete_cleanup_on_exception(self, operator): + event = {"status": JobStatus.FAILED, "message": "Job failed"} + with patch.object(operator, "_delete_cluster") as mock_delete_cluster, pytest.raises(AirflowException): + operator.execute_complete({}, event) + + mock_delete_cluster.assert_called_once() + + def test_execute_exception_handling(self, mock_hook, context): + operator = SubmitRayJob( + task_id="test_task", + conn_id="test_conn", + entrypoint="python script.py", + runtime_env={}, + ray_cluster_yaml="cluster.yaml", + ) + + mock_hook.submit_ray_job.side_effect = Exception("Job submission failed") + + with patch.object(operator, "_setup_cluster"), patch.object( + operator, "_delete_cluster" + ) as mock_delete_cluster, pytest.raises(AirflowException) as exc_info: + operator.execute(context) + + assert "SubmitRayJob operator failed due to Job submission failed" in str(exc_info.value) + mock_delete_cluster.assert_called_once() + + def test_execute_cluster_setup_exception(self, mock_hook, context): + operator = SubmitRayJob( + task_id="test_task", + conn_id="test_conn", + entrypoint="python script.py", + runtime_env={}, + ray_cluster_yaml="cluster.yaml", + ) + + with patch.object(operator, "_setup_cluster", side_effect=Exception("Cluster setup failed")), patch.object( + operator, "_delete_cluster" + ) as mock_delete_cluster, pytest.raises(AirflowException) as exc_info: + operator.execute(context) + + assert "SubmitRayJob operator failed due to Cluster setup failed" in str(exc_info.value) + mock_delete_cluster.assert_called_once() + + def test_execute_with_wait_and_defer(self, mock_hook, context): + operator = SubmitRayJob( + task_id="test_task", + conn_id="test_conn", + entrypoint="python script.py", + runtime_env={}, + wait_for_completion=True, + poll_interval=30, + fetch_logs=True, + job_timeout_seconds=600, + ) + + mock_hook.submit_ray_job.return_value = "test_job_id" + mock_hook.get_ray_job_status.return_value = JobStatus.PENDING + + with patch.object(operator, "_setup_cluster"), patch.object(operator, "defer") as mock_defer: + operator.execute(context) + + mock_defer.assert_called_once() + args, kwargs = mock_defer.call_args + assert isinstance(kwargs["trigger"], RayJobTrigger) + assert kwargs["method_name"] == "execute_complete" + assert 
kwargs["timeout"].total_seconds() == 600 + + def test_execute_complete_with_cleanup(self, operator): + operator.job_id = "test_job_id" + event = {"status": JobStatus.FAILED, "message": "Job failed"} + + with patch.object(operator, "_delete_cluster") as mock_delete_cluster, pytest.raises(AirflowException): + operator.execute_complete({}, event) + + mock_delete_cluster.assert_called_once() + + def test_execute_without_wait_no_cleanup(self, mock_hook, context): + operator = SubmitRayJob( + task_id="test_task", + conn_id="test_conn", + entrypoint="python script.py", + runtime_env={}, + wait_for_completion=False, + ) + + mock_hook.submit_ray_job.return_value = "test_job_id" + + with patch.object(operator, "_setup_cluster") as mock_setup_cluster, patch.object( + operator, "_delete_cluster" + ) as mock_delete_cluster: + result = operator.execute(context) + + mock_setup_cluster.assert_called_once_with(context=context) + assert result == "test_job_id" + mock_hook.submit_ray_job.assert_called_once_with( + dashboard_url=None, + entrypoint="python script.py", + runtime_env={}, + entrypoint_num_cpus=0, + entrypoint_num_gpus=0, + entrypoint_memory=0, + entrypoint_resources=None, + ) + mock_delete_cluster.assert_not_called() diff --git a/tests/triggers/test_ray_triggers.py b/tests/triggers/test_ray_triggers.py index b553d03..f82b521 100644 --- a/tests/triggers/test_ray_triggers.py +++ b/tests/triggers/test_ray_triggers.py @@ -1,4 +1,5 @@ -from unittest.mock import patch +import logging +from unittest.mock import AsyncMock, call, patch import pytest from airflow.triggers.base import TriggerEvent @@ -14,6 +15,8 @@ def trigger(self): job_id="test_job_id", conn_id="test_conn", xcom_dashboard_url="http://test-dashboard.com", + ray_cluster_yaml="test.yaml", + gpu_device_plugin_yaml="nvidia.yaml", poll_interval=1, fetch_logs=True, ) @@ -24,7 +27,14 @@ def trigger(self): async def test_run_no_job_id(self, mock_hook, mock_is_terminal): mock_is_terminal.return_value = True mock_hook.get_ray_job_status.return_value = JobStatus.FAILED - trigger = RayJobTrigger(job_id="", poll_interval=1, conn_id="test", xcom_dashboard_url="test") + trigger = RayJobTrigger( + job_id="", + poll_interval=1, + conn_id="test", + xcom_dashboard_url="test", + ray_cluster_yaml="test.yaml", + gpu_device_plugin_yaml="nvidia.yaml", + ) generator = trigger.run() event = await generator.asend(None) assert event == TriggerEvent( @@ -37,7 +47,14 @@ async def test_run_no_job_id(self, mock_hook, mock_is_terminal): async def test_run_job_succeeded(self, mock_hook, mock_is_terminal): mock_is_terminal.side_effect = [False, True] mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED - trigger = RayJobTrigger(job_id="test_job_id", poll_interval=1, conn_id="test", xcom_dashboard_url="test") + trigger = RayJobTrigger( + job_id="test_job_id", + poll_interval=1, + conn_id="test", + xcom_dashboard_url="test", + ray_cluster_yaml="test.yaml", + gpu_device_plugin_yaml="nvidia.yaml", + ) generator = trigger.run() event = await generator.asend(None) assert event == TriggerEvent( @@ -105,23 +122,6 @@ async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_is } ) - @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") - async def test_run_with_exception(self, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = Exception("Test exception") - - generator = trigger.run() - event = await generator.asend(None) - - assert event 
== TriggerEvent( - { - "status": str(JobStatus.FAILED), - "message": "Test exception", - "job_id": "test_job_id", - } - ) - @pytest.mark.asyncio @patch("ray_provider.triggers.ray.RayJobTrigger.hook") async def test_stream_logs(self, mock_hook, trigger): @@ -149,6 +149,8 @@ def test_serialize(self, trigger): "job_id": "test_job_id", "conn_id": "test_conn", "xcom_dashboard_url": "http://test-dashboard.com", + "ray_cluster_yaml": "test.yaml", + "gpu_device_plugin_yaml": "nvidia.yaml", "fetch_logs": True, "poll_interval": 1, }, @@ -166,3 +168,74 @@ async def test_is_terminal_state(self, mock_hook, trigger): assert not trigger._is_terminal_state() assert not trigger._is_terminal_state() assert trigger._is_terminal_state() + + @pytest.mark.asyncio + @patch.object(RayJobTrigger, "hook") + @patch.object(logging.Logger, "info") + async def test_cleanup_with_cluster_yaml(self, mock_log_info, mock_hook, trigger): + await trigger.cleanup() + + mock_log_info.assert_has_calls( + [ + call("Attempting to delete Ray cluster using YAML: test.yaml"), + call("Ray cluster deletion process completed"), + ] + ) + mock_hook.delete_ray_cluster.assert_called_once_with("test.yaml", "nvidia.yaml") + + @pytest.mark.asyncio + @patch.object(logging.Logger, "info") + async def test_cleanup_without_cluster_yaml(self, mock_log_info): + trigger = RayJobTrigger( + job_id="test_job_id", + conn_id="test_conn", + xcom_dashboard_url="http://test-dashboard.com", + ray_cluster_yaml=None, + gpu_device_plugin_yaml="nvidia.yaml", + poll_interval=1, + fetch_logs=True, + ) + + await trigger.cleanup() + + mock_log_info.assert_called_once_with("No Ray cluster YAML provided, skipping cluster deletion") + + @pytest.mark.asyncio + @patch.object(RayJobTrigger, "hook") + @patch.object(logging.Logger, "error") + async def test_cleanup_with_exception(self, mock_log_error, mock_hook, trigger): + mock_hook.delete_ray_cluster.side_effect = Exception("Test exception") + + await trigger.cleanup() + + mock_log_error.assert_called_once_with("Unexpected error during cleanup: Test exception") + + @pytest.mark.asyncio + @patch("asyncio.sleep", new_callable=AsyncMock) + @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") + async def test_poll_status(self, mock_is_terminal, mock_sleep, trigger): + mock_is_terminal.side_effect = [False, False, True] + + await trigger._poll_status() + + assert mock_sleep.call_count == 2 + mock_sleep.assert_called_with(1) + + @pytest.mark.asyncio + @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.ray.RayJobTrigger.cleanup") + async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_is_terminal, trigger): + mock_is_terminal.side_effect = Exception("Test exception") + + generator = trigger.run() + event = await generator.asend(None) + + assert event == TriggerEvent( + { + "status": str(JobStatus.FAILED), + "message": "Test exception", + "job_id": "test_job_id", + } + ) + mock_cleanup.assert_called_once()
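A hedged sketch of how the reworked `job_timeout_seconds` handling behaves from a DAG author's perspective is shown below. The `SubmitRayJob` parameter names follow the tests added in this patch; the connection id, entrypoint scripts, and cluster YAML path are placeholder values.

```python
# Sketch under assumptions: "ray_conn", the entrypoint scripts, and the cluster
# YAML path are placeholders; parameter names follow the tests in this patch.
from datetime import datetime

from airflow import DAG
from ray_provider.operators.ray import SubmitRayJob

with DAG(dag_id="ray_job_timeout_example", start_date=datetime(2024, 1, 1), schedule=None, catchup=False):
    # Stored internally as timedelta(seconds=600) and passed as the deferral
    # timeout, so polling is abandoned after 10 minutes.
    bounded = SubmitRayJob(
        task_id="bounded_job",
        conn_id="ray_conn",
        entrypoint="python train.py",
        runtime_env={"pip": ["torch"]},
        ray_cluster_yaml="ray-cluster.yaml",
        job_timeout_seconds=600,
    )

    # job_timeout_seconds=0 now maps to None, so the RayJobTrigger polls until
    # the job reaches a terminal state with no deferral timeout.
    unbounded = SubmitRayJob(
        task_id="unbounded_job",
        conn_id="ray_conn",
        entrypoint="python serve.py",
        runtime_env={},
        ray_cluster_yaml="ray-cluster.yaml",
        job_timeout_seconds=0,
    )

    bounded >> unbounded
```

With the trigger now receiving `ray_cluster_yaml` and `gpu_device_plugin_yaml`, a failure while polling in deferred mode tears down the cluster instead of leaking it, which is what the new `cleanup()` tests assert.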