Merge branch 'main' into decorator_bug_fix

tatiana authored Sep 27, 2024
2 parents a8629e6 + a900d43 commit cbf93ee
Showing 8 changed files with 397 additions and 85 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -2,8 +2,8 @@ name: test
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
pull_request_target:
branches: ["main"]

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
21 changes: 10 additions & 11 deletions docs/index.rst
@@ -1,4 +1,4 @@
Welcome to Ray provider documentation!
Welcome to the Ray provider documentation!
======================================

.. toctree::
@@ -12,14 +12,14 @@ Welcome to Ray provider documentation!
API Reference <api/ray_provider>
Contributing <CONTRIBUTING>

This repository provides tools for integrating `Apache Airflow®`_ with Ray, enabling the orchestration of Ray jobs within Airflow DAGs. It includes a decorator, two operators, and one trigger designed to efficiently manage and monitor Ray jobs and services.
This repository contains modules for integrating `Apache Airflow®`_ with Ray, enabling the orchestration of Ray jobs from Airflow DAGs. It includes a decorator, two operators, and one trigger designed to efficiently manage and monitor Ray jobs and services.

Benefits of using this provider include:

- **Integration**: Incorporate Ray jobs into Airflow DAGs for unified workflow management.
- **Distributed computing**: Use Ray's distributed capabilities within Airflow pipelines for scalable ETL, LLM fine-tuning etc.
- **Monitoring**: Track Ray job progress through Airflow's user interface.
- **Dependency management**: Define and manage dependencies between Ray jobs and other tasks in DAGs.
- **Dependency management**: Define and manage dependencies between Ray jobs and other tasks in Airflow DAGs.
- **Resource allocation**: Run Ray jobs alongside other task types within a single pipeline.

.. _Apache Airflow®: https://airflow.apache.org/
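To make the description above concrete, here is a minimal DAG sketch that uses the decorator this provider ships. It is an assumption-laden sketch: the import path ray_provider.decorators.ray and the @ray.task name follow the module touched by this commit, the config keys are assumed to mirror the SubmitRayJob parameters documented further down, and the connection ID, dependencies, and resource values are placeholders rather than values from this repository.

    from datetime import datetime

    from airflow.decorators import dag
    from ray_provider.decorators.ray import ray  # assumed import path for the @ray.task decorator

    RAY_TASK_CONFIG = {
        "conn_id": "ray_conn",              # placeholder Airflow connection to a Ray cluster
        "runtime_env": {"pip": ["numpy"]},  # environment shipped alongside the job
        "num_cpus": 1,
        "job_timeout_seconds": 600,         # 0 would disable the timeout (see the operator change below)
    }

    @dag(start_date=datetime(2024, 1, 1), schedule=None, catchup=False)
    def ray_decorator_example():
        @ray.task(config=RAY_TASK_CONFIG)
        def square(x: int) -> int:
            return x * x

        square(4)

    ray_decorator_example()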
@@ -39,15 +39,15 @@ Quickstart

See the :doc:`Getting Started <getting_started/setup>` page for detailed instructions on how to begin using the provider.

What is the Ray provider?
Why use Airflow with Ray?
-------------------------

Enterprise data value extraction involves two crucial components:
Value creation from data in an enterprise environment involves two crucial components:

- Data Engineering
- Data Science/ML/AI
- Data Engineering (ETL/ELT/Infrastructure Management)
- Data Science (ML/AI)

While Airflow excels at data engineering tasks through its extensive plugin ecosystem, it generally relies on external systems when dealing with large-scale ETL(100s GB to PB scale) or AI tasks such as fine-tuning & deploying LLMs etc.
While Airflow excels at orchestrating both, data engineering and data science related tasks through its extensive provider ecosystem, it often relies on external systems when dealing with large-scale (100s GB to PB scale) data and compute (GPU) requirements, such as fine-tuning & deploying LLMs etc.

Ray is a particularly powerful platform for handling large scale computations and this provider makes it very straightforward to orchestrate Ray jobs from Airflow.

@@ -57,12 +57,11 @@ Ray is a particularly powerful platform for handling large scale computations an
:width: 499
:height: 561

The architecture diagram above shows how we can deploy both Airflow & Ray on a Kubernetes cluster for elastic compute.

The architecture diagram above shows that we can run both, Airflow and Ray side by side on Kubernetes to leverage the best of both worlds. Airflow can be used to orchestrate Ray jobs and services, while Ray can be used to run distributed computations.

Use Cases
^^^^^^^^^
- **Scalable ETL**: Orchestrate and monitor Ray jobs on on-demand compute clusters using the Ray Data library. These operations could be custom Python code or ML model inference.
- **Scalable ETL**: Orchestrate and monitor Ray jobs to perform distributed ETL for heavy data loads on on-demand compute clusters using the Ray Data library.
- **Model Training**: Schedule model training or fine-tuning jobs on flexible cadences (daily/weekly/monthly). Benefits include:

* Optimize resource utilization by scheduling Ray jobs during cost-effective periods
4 changes: 2 additions & 2 deletions ray_provider/decorators/ray.py
@@ -31,11 +31,11 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob):

template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs")

def __init__(self, config: dict[str, Any] | Callable[..., dict[str, Any]], **kwargs: Any) -> None:
def __init__(self, config: dict[str, Any] | Callable[..., dict[str, Any]], **kwargs: Any) -> None:
job_timeout_seconds: int = config.get("job_timeout_seconds", 600)

self.config = config
self.kwargs = kwargs

super().__init__(conn_id="", entrypoint="python script.py", runtime_env={}, **kwargs)

def get_config(self, context: Context, config: Callable[..., dict[str, Any]], **kwargs: Any) -> dict[str, Any]:
98 changes: 52 additions & 46 deletions ray_provider/operators/ray.py
@@ -116,7 +116,7 @@ class SubmitRayJob(BaseOperator):
:param gpu_device_plugin_yaml: URL or path to the GPU device plugin YAML file. Defaults to NVIDIA's plugin.
:param fetch_logs: Whether to fetch logs from the Ray job. Defaults to True.
:param wait_for_completion: Whether to wait for the job to complete before marking the task as finished. Defaults to True.
:param job_timeout_seconds: Maximum time to wait for job completion in seconds. Defaults to 600 seconds.
:param job_timeout_seconds: Maximum time to wait for job completion in seconds. Defaults to 600 seconds. Set to 0 if you want the job to run indefinitely without timeouts.
:param poll_interval: Interval between job status checks in seconds. Defaults to 60 seconds.
:param xcom_task_key: XCom key to retrieve the dashboard URL. Defaults to None.
"""
@@ -168,7 +168,7 @@ def __init__(
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
self.fetch_logs = fetch_logs
self.wait_for_completion = wait_for_completion
self.job_timeout_seconds = job_timeout_seconds
self.job_timeout_seconds = timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None
self.poll_interval = poll_interval
self.xcom_task_key = xcom_task_key
self.dashboard_url: str | None = None
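For reference, the new assignment converts a positive integer into a timedelta and treats 0 (or any non-positive value) as "no timeout", which is the None that BaseOperator.defer accepts for an unbounded wait. A standalone restatement of that conversion (the helper name is illustrative):

    from __future__ import annotations

    from datetime import timedelta

    def to_defer_timeout(job_timeout_seconds: int) -> timedelta | None:
        # Positive values become a hard deadline for the deferred trigger;
        # zero means "wait indefinitely", i.e. defer() receives timeout=None.
        return timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None

    assert to_defer_timeout(600) == timedelta(seconds=600)
    assert to_defer_timeout(0) is None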
@@ -185,8 +185,7 @@ def on_kill(self) -> None:
if hasattr(self, "hook") and self.job_id:
self.log.info(f"Deleting Ray job {self.job_id} due to task kill.")
self.hook.delete_ray_job(self.dashboard_url, self.job_id)
if self.ray_cluster_yaml:
self._delete_cluster()
self._delete_cluster()

@cached_property
def hook(self) -> PodOperatorHookProtocol:
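Dropping the caller-side check is only safe if the cluster-spec guard lives inside _delete_cluster itself. That helper is not part of this diff; as an assumption, it presumably reduces to something like the sketch below, reusing the hook's delete_ray_cluster call that also appears in the trigger changes further down.

    def _delete_cluster(self) -> None:
        # Hypothetical sketch: _delete_cluster is not shown in this diff.
        # Deletion is a no-op when no cluster spec was provided, so callers such as
        # on_kill no longer need their own `if self.ray_cluster_yaml:` check.
        if self.ray_cluster_yaml:
            self.hook.delete_ray_cluster(self.ray_cluster_yaml, self.gpu_device_plugin_yaml)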
@@ -262,48 +261,55 @@ def execute(self, context: Context) -> str:
:raises AirflowException: If the job fails, is cancelled, or reaches an unexpected state.
"""

self._setup_cluster(context=context)

self.dashboard_url = self._get_dashboard_url(context)

self.job_id = self.hook.submit_ray_job(
dashboard_url=self.dashboard_url,
entrypoint=self.entrypoint,
runtime_env=self.runtime_env,
entrypoint_num_cpus=self.num_cpus,
entrypoint_num_gpus=self.num_gpus,
entrypoint_memory=self.memory,
entrypoint_resources=self.ray_resources,
)
self.log.info(f"Ray job submitted with id: {self.job_id}")

if self.wait_for_completion:
current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in self.terminal_states:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
trigger=RayJobTrigger(
job_id=self.job_id,
conn_id=self.conn_id,
xcom_dashboard_url=self.dashboard_url,
poll_interval=self.poll_interval,
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.job_timeout_seconds),
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
elif current_status == JobStatus.FAILED:
raise AirflowException(f"Job failed:\n{self.job_id}")
elif current_status == JobStatus.STOPPED:
raise AirflowException(f"Job was cancelled:\n{self.job_id}")
else:
raise AirflowException(f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`")

return self.job_id
try:
self._setup_cluster(context=context)

self.dashboard_url = self._get_dashboard_url(context)

self.job_id = self.hook.submit_ray_job(
dashboard_url=self.dashboard_url,
entrypoint=self.entrypoint,
runtime_env=self.runtime_env,
entrypoint_num_cpus=self.num_cpus,
entrypoint_num_gpus=self.num_gpus,
entrypoint_memory=self.memory,
entrypoint_resources=self.ray_resources,
)
self.log.info(f"Ray job submitted with id: {self.job_id}")

if self.wait_for_completion:
current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in self.terminal_states:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
trigger=RayJobTrigger(
job_id=self.job_id,
conn_id=self.conn_id,
xcom_dashboard_url=self.dashboard_url,
ray_cluster_yaml=self.ray_cluster_yaml,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
poll_interval=self.poll_interval,
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=self.job_timeout_seconds,
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
elif current_status == JobStatus.FAILED:
raise AirflowException(f"Job failed:\n{self.job_id}")
elif current_status == JobStatus.STOPPED:
raise AirflowException(f"Job was cancelled:\n{self.job_id}")
else:
raise AirflowException(
f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`"
)
return self.job_id
except Exception as e:
self._delete_cluster()
raise AirflowException(f"SubmitRayJob operator failed due to {e}. Cleaning up resources...")

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
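The rest of execute_complete is cut off here, but the trigger below yields TriggerEvent payloads carrying status, message, and job_id keys, with status serialized via str(JobStatus...). The handler's core check is therefore presumably along these lines; this is a hypothetical helper for illustration, not the provider's actual method:

    from airflow.exceptions import AirflowException
    from ray.job_submission import JobStatus

    def handle_ray_trigger_event(event: dict) -> str:
        # Succeed only when the trigger reports SUCCEEDED; surface anything else as a task failure.
        if event["status"] != str(JobStatus.SUCCEEDED):
            raise AirflowException(
                f"Ray job {event['job_id']} finished with status {event['status']}: {event['message']}"
            )
        return event["job_id"]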
30 changes: 30 additions & 0 deletions ray_provider/triggers/ray.py
@@ -30,13 +30,17 @@ def __init__(
job_id: str,
conn_id: str,
xcom_dashboard_url: str | None,
ray_cluster_yaml: str | None,
gpu_device_plugin_yaml: str,
poll_interval: int = 30,
fetch_logs: bool = True,
):
super().__init__() # type: ignore[no-untyped-call]
self.job_id = job_id
self.conn_id = conn_id
self.dashboard_url = xcom_dashboard_url
self.ray_cluster_yaml = ray_cluster_yaml
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
self.fetch_logs = fetch_logs
self.poll_interval = poll_interval

Expand All @@ -52,6 +56,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"job_id": self.job_id,
"conn_id": self.conn_id,
"xcom_dashboard_url": self.dashboard_url,
"ray_cluster_yaml": self.ray_cluster_yaml,
"gpu_device_plugin_yaml": self.gpu_device_plugin_yaml,
"fetch_logs": self.fetch_logs,
"poll_interval": self.poll_interval,
},
@@ -66,6 +72,28 @@ def hook(self) -> RayHook:
"""
return RayHook(conn_id=self.conn_id)

async def cleanup(self) -> None:
"""
Cleanup method to ensure resources are properly deleted. This will be called when the trigger encounters an exception.
Example scenario: A job is submitted using the @ray.task decorator with a Ray specification. After the cluster is started
and the job is submitted, the trigger begins tracking its progress. However, if the job is stopped through the UI at this stage, the cluster
resources are not deleted.
"""
try:
if self.ray_cluster_yaml:
self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}")
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml
)
self.log.info("Ray cluster deletion process completed")
else:
self.log.info("No Ray cluster YAML provided, skipping cluster deletion")
except Exception as e:
self.log.error(f"Unexpected error during cleanup: {str(e)}")

async def _poll_status(self) -> None:
while not self._is_terminal_state():
await asyncio.sleep(self.poll_interval)
@@ -109,6 +137,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
}
)
except Exception as e:
self.log.error(f"Error occurred: {str(e)}")
await self.cleanup()
yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id})

def _is_terminal_state(self) -> bool:
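The new cleanup method has to invoke the synchronous RayHook.delete_ray_cluster from inside the triggerer's event loop, which is why it goes through loop.run_in_executor instead of calling the hook directly. A self-contained sketch of that pattern follows; the function names are stand-ins, not the provider's API.

    from __future__ import annotations

    import asyncio

    def delete_cluster_blocking(cluster_yaml: str, gpu_plugin_yaml: str) -> None:
        # Stand-in for the synchronous hook call (RayHook.delete_ray_cluster in the provider).
        print(f"deleting cluster defined by {cluster_yaml} (device plugin: {gpu_plugin_yaml})")

    async def cleanup(cluster_yaml: str | None, gpu_plugin_yaml: str) -> None:
        # Mirrors the shape of RayJobTrigger.cleanup: skip when no cluster spec was given,
        # otherwise push the blocking call onto the default thread-pool executor so the
        # event loop keeps servicing other triggers.
        if not cluster_yaml:
            return
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, delete_cluster_blocking, cluster_yaml, gpu_plugin_yaml)

    asyncio.run(cleanup("ray-cluster.yaml", "nvidia-device-plugin.yml"))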
5 changes: 3 additions & 2 deletions tests/decorators/test_ray_decorators.py
@@ -1,3 +1,4 @@
from datetime import timedelta
from unittest.mock import MagicMock, patch

import pytest
@@ -38,7 +39,7 @@ def dummy_callable():
assert operator.ray_resources == {"custom_resource": 1}
assert operator.fetch_logs == True
assert operator.wait_for_completion == True
assert operator.job_timeout_seconds == 300
assert operator.job_timeout_seconds == timedelta(seconds=300)
assert operator.poll_interval == 30
assert operator.xcom_task_key == "ray_result"

@@ -59,7 +60,7 @@ def dummy_callable():
assert operator.ray_resources is None
assert operator.fetch_logs == True
assert operator.wait_for_completion == True
assert operator.job_timeout_seconds == 600
assert operator.job_timeout_seconds == timedelta(seconds=600)
assert operator.poll_interval == 60
assert operator.xcom_task_key is None
