Skip to content

Commit

Permalink
[spark] Automatically shut down ray on spark cluster if user does not…
Browse files Browse the repository at this point in the history
… execute commands on databricks notebook for a long time (ray-project#31962)

Databricks Runtime provides an API:
dbutils.entry_point.getIdleTimeMillisSinceLastNotebookExecution() that returns elapsed milliseconds since last databricks notebook code execution.
This PR code calls this interface to monitor notebook activity and shut down Ray cluster on timeout.

Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 authored and clarence-wu committed Jan 31, 2023
1 parent e86b363 commit 1f73347
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 37 deletions.
1 change: 1 addition & 0 deletions python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ py_test_module_list(
"spark/test_GPU.py",
"spark/test_multicores_per_task.py",
"spark/test_utils.py",
"spark/test_databricks_hook.py",
],
size = "large",
tags = ["exclusive", "spark_plugin_tests", "team:serverless"],
Expand Down
16 changes: 15 additions & 1 deletion python/ray/tests/spark/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class RayOnSparkCPUClusterTestBase(ABC):
@classmethod
def teardown_class(cls):
time.sleep(10) # Wait all background spark job canceled.
os.environ.pop("SPARK_WORKER_CORES", None)
cls.spark.stop()

@staticmethod
Expand All @@ -70,6 +71,11 @@ def test_cpu_allocation(self):
self.num_cpus_per_spark_task * 2,
MAX_NUM_WORKER_NODES,
),
(
self.max_spark_tasks // 2,
self.num_cpus_per_spark_task * 2,
self.max_spark_tasks // 2 + 1,
), # Test case: requesting resources exceeding all cluster resources
]:
with _setup_ray_cluster(
num_worker_nodes=num_worker_nodes_arg,
Expand All @@ -92,6 +98,12 @@ def test_public_api(self):
ray_temp_root_dir=ray_temp_root_dir,
head_node_options={"include_dashboard": True},
)

assert (
os.environ["RAY_ADDRESS"]
== ray.util.spark.cluster_init._active_ray_cluster.address
)

ray.init()

@ray.remote
Expand All @@ -104,6 +116,8 @@ def f(x):

shutdown_ray_cluster()

assert "RAY_ADDRESS" not in os.environ

time.sleep(7)
# assert temp dir is removed.
assert len(os.listdir(ray_temp_root_dir)) == 1 and os.listdir(
Expand All @@ -121,7 +135,7 @@ def f(x):
if ray.util.spark.cluster_init._active_ray_cluster is not None:
# if the test raised error and does not destroy cluster,
# destroy it here.
ray.util.spark._active_ray_cluster.shutdown()
ray.util.spark.cluster_init._active_ray_cluster.shutdown()
time.sleep(5)
shutil.rmtree(ray_temp_root_dir, ignore_errors=True)
shutil.rmtree(collect_log_to_path, ignore_errors=True)
Expand Down
80 changes: 80 additions & 0 deletions python/ray/tests/spark/test_databricks_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import sys

import pytest
import os
import time
import ray
from pyspark.sql import SparkSession
from ray.util.spark import setup_ray_cluster
import ray.util.spark.databricks_hook


pytestmark = pytest.mark.skipif(
not sys.platform.startswith("linux"),
reason="Ray on spark only supports running on Linux.",
)


class MockDbApiEntry:
def __init__(self):
self.created_time = time.time()
self.registered_job_groups = []

def getIdleTimeMillisSinceLastNotebookExecution(self):
return (time.time() - self.created_time) * 1000

def registerBackgroundSparkJobGroup(self, job_group_id):
self.registered_job_groups.append(job_group_id)


class TestDatabricksHook:
@classmethod
def setup_class(cls):
os.environ["SPARK_WORKER_CORES"] = "2"
cls.spark = (
SparkSession.builder.master("local-cluster[1, 2, 1024]")
.config("spark.task.cpus", "1")
.config("spark.task.maxFailures", "1")
.config("spark.executorEnv.RAY_ON_SPARK_WORKER_CPU_CORES", "2")
.getOrCreate()
)

@classmethod
def teardown_class(cls):
time.sleep(10) # Wait all background spark job canceled.
cls.spark.stop()
os.environ.pop("SPARK_WORKER_CORES")

def test_hook(self, monkeypatch):
monkeypatch.setattr(
"ray.util.spark.databricks_hook._DATABRICKS_DEFAULT_TMP_DIR", "/tmp"
)
monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "12.2")
monkeypatch.setenv("DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES", "0.5")
db_api_entry = MockDbApiEntry()
monkeypatch.setattr(
"ray.util.spark.databricks_hook._get_db_api_entry", lambda: db_api_entry
)
try:
setup_ray_cluster(
num_worker_nodes=2,
head_node_options={"include_dashboard": False},
)
cluster = ray.util.spark.cluster_init._active_ray_cluster
assert not cluster.is_shutdown
assert db_api_entry.registered_job_groups == [cluster.spark_job_group_id]
time.sleep(35)
assert cluster.is_shutdown
assert ray.util.spark.cluster_init._active_ray_cluster is None
finally:
if ray.util.spark.cluster_init._active_ray_cluster is not None:
# if the test raised error and does not destroy cluster,
# destroy it here.
ray.util.spark.cluster_init._active_ray_cluster.shutdown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
sys.exit(pytest.main(["-sv", __file__]))
70 changes: 42 additions & 28 deletions python/ray/util/spark/cluster_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ def _cancel_background_spark_job(self):
def wait_until_ready(self):
import ray

if self.background_job_exception is not None:
raise RuntimeError(
"Ray workers has exited."
) from self.background_job_exception

if self.is_shutdown:
raise RuntimeError(
"The ray cluster has been shut down or it failed to start."
Expand All @@ -125,6 +120,16 @@ def wait_until_ready(self):
last_progress_move_time = time.time()
while True:
time.sleep(_RAY_CLUSTER_STARTUP_PROGRESS_CHECKING_INTERVAL)

# Inside the waiting ready loop,
# checking `self.background_job_exception`, if it is not None,
# it means the background spark job has failed,
# in this case, raise error directly.
if self.background_job_exception is not None:
raise RuntimeError(
"Ray workers have exited."
) from self.background_job_exception

cur_alive_worker_count = (
len([node for node in ray.nodes() if node["Alive"]]) - 1
) # Minus 1 means excluding the head node.
Expand All @@ -144,12 +149,17 @@ def wait_until_ready(self):
time.time() - last_progress_move_time
> _RAY_CONNECT_CLUSTER_POLL_PROGRESS_TIMEOUT
):
if cur_alive_worker_count == 0:
raise RuntimeError(
"Current spark cluster has no resources to launch "
"Ray worker nodes."
)
_logger.warning(
"Timeout in waiting for all ray workers to start. "
"Started / Total requested: "
f"({cur_alive_worker_count} / {self.num_worker_nodes}). "
"Please check ray logs to see why some ray workers "
"failed to start."
"Current spark cluster does not have sufficient resources "
"to launch requested number of Ray worker nodes."
)
return
finally:
Expand Down Expand Up @@ -678,7 +688,7 @@ def background_job_thread_fn():
).start()

# Call hook immediately after spark job started.
start_hook.on_spark_background_job_created(spark_job_group_id)
start_hook.on_cluster_created(ray_cluster_handler)

# wait background spark task starting.
for _ in range(_BACKGROUND_JOB_STARTUP_WAIT):
Expand All @@ -699,6 +709,7 @@ def background_job_thread_fn():


_active_ray_cluster = None
_active_ray_cluster_rwlock = threading.RLock()


def _create_resource_profile(num_cpus_per_node, num_gpus_per_node):
Expand Down Expand Up @@ -1007,23 +1018,24 @@ def setup_ray_cluster(
else:
_logger.warning("\n".join(insufficient_resources))

cluster = _setup_ray_cluster(
num_worker_nodes=num_worker_nodes,
num_cpus_per_node=num_cpus_per_node,
num_gpus_per_node=num_gpus_per_node,
using_stage_scheduling=using_stage_scheduling,
heap_memory_per_node=ray_worker_node_heap_mem_bytes,
object_store_memory_per_node=ray_worker_node_object_store_mem_bytes,
head_node_options=head_node_options,
worker_node_options=worker_node_options,
ray_temp_root_dir=ray_temp_root_dir,
collect_log_to_path=collect_log_to_path,
)
cluster.wait_until_ready() # NB: this line might raise error.
with _active_ray_cluster_rwlock:
cluster = _setup_ray_cluster(
num_worker_nodes=num_worker_nodes,
num_cpus_per_node=num_cpus_per_node,
num_gpus_per_node=num_gpus_per_node,
using_stage_scheduling=using_stage_scheduling,
heap_memory_per_node=ray_worker_node_heap_mem_bytes,
object_store_memory_per_node=ray_worker_node_object_store_mem_bytes,
head_node_options=head_node_options,
worker_node_options=worker_node_options,
ray_temp_root_dir=ray_temp_root_dir,
collect_log_to_path=collect_log_to_path,
)
cluster.wait_until_ready() # NB: this line might raise error.

# If connect cluster successfully, set global _active_ray_cluster to be the started
# cluster.
_active_ray_cluster = cluster
# If connect cluster successfully, set global _active_ray_cluster to be the
# started cluster.
_active_ray_cluster = cluster
return cluster.address


Expand All @@ -1033,8 +1045,10 @@ def shutdown_ray_cluster() -> None:
Shut down the active ray cluster.
"""
global _active_ray_cluster
if _active_ray_cluster is None:
raise RuntimeError("No active ray cluster to shut down.")

_active_ray_cluster.shutdown()
_active_ray_cluster = None
with _active_ray_cluster_rwlock:
if _active_ray_cluster is None:
raise RuntimeError("No active ray cluster to shut down.")

_active_ray_cluster.shutdown()
_active_ray_cluster = None
99 changes: 93 additions & 6 deletions python/ray/util/spark/databricks_hook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import os

from .start_hook_base import RayOnSparkStartHook
from .utils import get_spark_session
import logging
import threading
import time

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,21 +58,104 @@ def display_databricks_driver_proxy_url(spark_context, port, title):
)


DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS = 3
DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES = (
"DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES"
)


def _get_db_api_entry():
"""
Get databricks API entry point.
"""
return get_dbutils().entry_point


_DATABRICKS_DEFAULT_TMP_DIR = "/local_disk0/tmp"


class DefaultDatabricksRayOnSparkStartHook(RayOnSparkStartHook):
def get_default_temp_dir(self):
return "/local_disk0/tmp"
return _DATABRICKS_DEFAULT_TMP_DIR

def on_ray_dashboard_created(self, port):
display_databricks_driver_proxy_url(
get_spark_session().sparkContext, port, "Ray Cluster Dashboard"
)

def on_spark_background_job_created(self, job_group_id):
def on_cluster_created(self, ray_cluster_handler):
db_api_entry = _get_db_api_entry()
try:
get_dbutils().entry_point.registerBackgroundSparkJobGroup(job_group_id)
db_api_entry.registerBackgroundSparkJobGroup(
ray_cluster_handler.spark_job_group_id
)
except Exception:
_logger.warning(
"Register ray cluster spark job as background job failed. You need to "
"manually call `ray_cluster_on_spark.shutdown()` before detaching "
"your databricks python REPL."
"Registering Ray cluster spark job as background job failed. "
"You need to manually call `ray.util.spark.shutdown_ray_cluster()` "
"before detaching your databricks notebook."
)

auto_shutdown_minutes = float(
os.environ.get(DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES, "30")
)
if auto_shutdown_minutes == 0:
_logger.info(
"The Ray cluster will keep running until you manually detach the "
"databricks notebook or call "
"`ray.util.spark.shutdown_ray_cluster()`."
)
return
if auto_shutdown_minutes < 0:
raise ValueError(
"You must set "
f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' "
"to a value >= 0."
)

try:
db_api_entry.getIdleTimeMillisSinceLastNotebookExecution()
except Exception:
_logger.warning(
"Databricks `getIdleTimeMillisSinceLastNotebookExecution` API "
"is unavailable, it is probably because that "
"your current Databricks Runtime version does not support API "
"`getIdleTimeMillisSinceLastNotebookExecution`, we cannot "
"automatically shut down Ray cluster when databricks notebook "
"is inactive, you need to manually detach databricks notebook "
"or call `ray.util.spark.shutdown_ray_cluster()` to shut down "
"Ray cluster on spark."
)
return

_logger.info(
"The Ray cluster will be shut down automatically if you don't run "
"commands on the databricks notebook for "
f"{auto_shutdown_minutes} minutes. You can change the "
"automatically shutdown minutes by setting "
f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' environment "
"variable, setting it to 0 means that the Ray cluster keeps running "
"until you manually call `ray.util.spark.shutdown_ray_cluster()` or "
"detach databricks notebook."
)

def auto_shutdown_watcher():
auto_shutdown_millis = auto_shutdown_minutes * 60 * 1000
while True:
if ray_cluster_handler.is_shutdown:
# The cluster is shut down. The watcher thread exits.
return

idle_time = db_api_entry.getIdleTimeMillisSinceLastNotebookExecution()

if idle_time > auto_shutdown_millis:
from ray.util.spark import cluster_init

with cluster_init._active_ray_cluster_rwlock:
if ray_cluster_handler is cluster_init._active_ray_cluster:
cluster_init.shutdown_ray_cluster()
return

time.sleep(DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS)

threading.Thread(target=auto_shutdown_watcher, daemon=True).start()
2 changes: 1 addition & 1 deletion python/ray/util/spark/start_hook_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ def get_default_temp_dir(self):
def on_ray_dashboard_created(self, port):
pass

def on_spark_background_job_created(self, job_group):
def on_cluster_created(self, ray_cluster_handler):
pass
2 changes: 1 addition & 1 deletion python/requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pygame==2.1.2; python_version < '3.11'
Pygments==2.13.0
pymongo==4.3.2
# TODO: Replace this with pyspark==3.4 once it is released.
https://ml-team-public-read.s3.us-west-2.amazonaws.com/pyspark-3.4.0.dev0.tar.gz
https://ml-team-public-read.s3.us-west-2.amazonaws.com/spark-pkgs/pyspark-3.4.0.dev0-0cb0fa313979e1b82ddd711a05d8c4e78cf6c9f5.tar.gz
pytest==7.0.1
pytest-asyncio==0.16.0
pytest-rerunfailures==10.2
Expand Down

0 comments on commit 1f73347

Please sign in to comment.