[spark] Automatically shut down ray on spark cluster if user does not execute commands on databricks notebook for a long time #31962

Merged
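For context, the auto-shutdown this PR adds is driven by an environment variable that must be set before the cluster is created. A minimal usage sketch from a Databricks notebook, assuming the variable is read at cluster-creation time as the databricks_hook.py diff below suggests (the value "10" is an arbitrary example):

import os

# Must be set before setup_ray_cluster(); the Databricks start hook reads
# it when the cluster is created. The default is "30" (minutes), and "0"
# disables auto-shutdown entirely.
os.environ["DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES"] = "10"

from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

setup_ray_cluster(num_worker_nodes=2)
# ... run Ray workloads ...
shutdown_ray_cluster()  # or let the idle watcher shut the cluster down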
1 change: 1 addition & 0 deletions python/ray/tests/BUILD
@@ -535,6 +535,7 @@ py_test_module_list(
"spark/test_GPU.py",
"spark/test_multicores_per_task.py",
"spark/test_utils.py",
"spark/test_databricks_hook.py",
],
size = "large",
tags = ["exclusive", "spark_plugin_tests", "team:serverless"],
16 changes: 15 additions & 1 deletion python/ray/tests/spark/test_basic.py
@@ -46,6 +46,7 @@ class RayOnSparkCPUClusterTestBase(ABC):
@classmethod
def teardown_class(cls):
time.sleep(10)  # Wait for all background spark jobs to be canceled.
os.environ.pop("SPARK_WORKER_CORES", None)
cls.spark.stop()

@staticmethod
@@ -70,6 +71,11 @@ def test_cpu_allocation(self):
self.num_cpus_per_spark_task * 2,
MAX_NUM_WORKER_NODES,
),
(
self.max_spark_tasks // 2,
self.num_cpus_per_spark_task * 2,
self.max_spark_tasks // 2 + 1,
), # Test case: requesting resources exceeding all cluster resources
]:
with _setup_ray_cluster(
num_worker_nodes=num_worker_nodes_arg,
@@ -92,6 +98,12 @@ def test_public_api(self):
ray_temp_root_dir=ray_temp_root_dir,
head_node_options={"include_dashboard": True},
)

assert (
os.environ["RAY_ADDRESS"]
== ray.util.spark.cluster_init._active_ray_cluster.address
)

ray.init()

@ray.remote
@@ -104,6 +116,8 @@ def f(x):

shutdown_ray_cluster()

assert "RAY_ADDRESS" not in os.environ

time.sleep(7)
# assert temp dir is removed.
assert len(os.listdir(ray_temp_root_dir)) == 1 and os.listdir(
@@ -121,7 +135,7 @@ def f(x):
if ray.util.spark.cluster_init._active_ray_cluster is not None:
# If the test raised an error and did not destroy the cluster,
# destroy it here.
ray.util.spark._active_ray_cluster.shutdown()
ray.util.spark.cluster_init._active_ray_cluster.shutdown()
time.sleep(5)
shutil.rmtree(ray_temp_root_dir, ignore_errors=True)
shutil.rmtree(collect_log_to_path, ignore_errors=True)
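The new assertions in test_public_api pin down a small contract worth spelling out: setup_ray_cluster() exports RAY_ADDRESS so that a bare ray.init() connects to the on-Spark cluster, and shutdown_ray_cluster() removes it again. A minimal sketch of that flow (num_worker_nodes=2 is an arbitrary example):

import os
import ray
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

address = setup_ray_cluster(num_worker_nodes=2)
assert os.environ["RAY_ADDRESS"] == address  # set by setup_ray_cluster()
ray.init()  # no explicit address needed; picks up RAY_ADDRESS

# ... run Ray workloads ...

shutdown_ray_cluster()
assert "RAY_ADDRESS" not in os.environ  # cleared on shutdown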
80 changes: 80 additions & 0 deletions python/ray/tests/spark/test_databricks_hook.py
@@ -0,0 +1,80 @@
import sys

import pytest
import os
import time
import ray
from pyspark.sql import SparkSession
from ray.util.spark import setup_ray_cluster
import ray.util.spark.databricks_hook


pytestmark = pytest.mark.skipif(
not sys.platform.startswith("linux"),
reason="Ray on spark only supports running on Linux.",
)


class MockDbApiEntry:
def __init__(self):
self.created_time = time.time()
self.registered_job_groups = []

def getIdleTimeMillisSinceLastNotebookExecution(self):
return (time.time() - self.created_time) * 1000

def registerBackgroundSparkJobGroup(self, job_group_id):
self.registered_job_groups.append(job_group_id)


class TestDatabricksHook:
@classmethod
def setup_class(cls):
os.environ["SPARK_WORKER_CORES"] = "2"
cls.spark = (
SparkSession.builder.master("local-cluster[1, 2, 1024]")
.config("spark.task.cpus", "1")
.config("spark.task.maxFailures", "1")
.config("spark.executorEnv.RAY_ON_SPARK_WORKER_CPU_CORES", "2")
.getOrCreate()
)

@classmethod
def teardown_class(cls):
time.sleep(10)  # Wait for all background spark jobs to be canceled.
cls.spark.stop()
os.environ.pop("SPARK_WORKER_CORES")

def test_hook(self, monkeypatch):
monkeypatch.setattr(
"ray.util.spark.databricks_hook._DATABRICKS_DEFAULT_TMP_DIR", "/tmp"
)
monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "12.2")
monkeypatch.setenv("DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES", "0.5")
db_api_entry = MockDbApiEntry()
monkeypatch.setattr(
"ray.util.spark.databricks_hook._get_db_api_entry", lambda: db_api_entry
)
try:
setup_ray_cluster(
num_worker_nodes=2,
head_node_options={"include_dashboard": False},
)
cluster = ray.util.spark.cluster_init._active_ray_cluster
assert not cluster.is_shutdown
assert db_api_entry.registered_job_groups == [cluster.spark_job_group_id]
time.sleep(35)
assert cluster.is_shutdown
assert ray.util.spark.cluster_init._active_ray_cluster is None
finally:
if ray.util.spark.cluster_init._active_ray_cluster is not None:
# If the test raised an error and did not destroy the cluster,
# destroy it here.
ray.util.spark.cluster_init._active_ray_cluster.shutdown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
sys.exit(pytest.main(["-sv", __file__]))
70 changes: 42 additions & 28 deletions python/ray/util/spark/cluster_init.py
@@ -100,11 +100,6 @@ def _cancel_background_spark_job(self):
def wait_until_ready(self):
import ray

if self.background_job_exception is not None:
raise RuntimeError(
"Ray workers has exited."
) from self.background_job_exception

if self.is_shutdown:
raise RuntimeError(
"The ray cluster has been shut down or it failed to start."
@@ -125,6 +120,16 @@ def wait_until_ready(self):
last_progress_move_time = time.time()
while True:
time.sleep(_RAY_CLUSTER_STARTUP_PROGRESS_CHECKING_INTERVAL)

# Inside the wait-until-ready loop, check
# `self.background_job_exception`; if it is not None,
# the background spark job has failed,
# so raise the error immediately.
if self.background_job_exception is not None:
raise RuntimeError(
"Ray workers have exited."
) from self.background_job_exception

cur_alive_worker_count = (
len([node for node in ray.nodes() if node["Alive"]]) - 1
) # Minus 1 means excluding the head node.
@@ -144,12 +149,17 @@
time.time() - last_progress_move_time
> _RAY_CONNECT_CLUSTER_POLL_PROGRESS_TIMEOUT
):
if cur_alive_worker_count == 0:
raise RuntimeError(
"Current spark cluster has no resources to launch "
"Ray worker nodes."
)
_logger.warning(
"Timeout in waiting for all ray workers to start. "
"Started / Total requested: "
f"({cur_alive_worker_count} / {self.num_worker_nodes}). "
"Please check ray logs to see why some ray workers "
"failed to start."
"Current spark cluster does not have sufficient resources "
"to launch requested number of Ray worker nodes."
)
return
finally:
@@ -678,7 +688,7 @@ def background_job_thread_fn():
).start()

# Call the hook immediately after the spark job has started.
start_hook.on_spark_background_job_created(spark_job_group_id)
start_hook.on_cluster_created(ray_cluster_handler)

# Wait for the background spark task to start.
for _ in range(_BACKGROUND_JOB_STARTUP_WAIT):
@@ -699,6 +709,7 @@ def background_job_thread_fn():


_active_ray_cluster = None
_active_ray_cluster_rwlock = threading.RLock()


def _create_resource_profile(num_cpus_per_node, num_gpus_per_node):
@@ -1007,23 +1018,24 @@ def setup_ray_cluster(
else:
_logger.warning("\n".join(insufficient_resources))

cluster = _setup_ray_cluster(
num_worker_nodes=num_worker_nodes,
num_cpus_per_node=num_cpus_per_node,
num_gpus_per_node=num_gpus_per_node,
using_stage_scheduling=using_stage_scheduling,
heap_memory_per_node=ray_worker_node_heap_mem_bytes,
object_store_memory_per_node=ray_worker_node_object_store_mem_bytes,
head_node_options=head_node_options,
worker_node_options=worker_node_options,
ray_temp_root_dir=ray_temp_root_dir,
collect_log_to_path=collect_log_to_path,
)
cluster.wait_until_ready() # NB: this line might raise error.
with _active_ray_cluster_rwlock:
cluster = _setup_ray_cluster(
num_worker_nodes=num_worker_nodes,
num_cpus_per_node=num_cpus_per_node,
num_gpus_per_node=num_gpus_per_node,
using_stage_scheduling=using_stage_scheduling,
heap_memory_per_node=ray_worker_node_heap_mem_bytes,
object_store_memory_per_node=ray_worker_node_object_store_mem_bytes,
head_node_options=head_node_options,
worker_node_options=worker_node_options,
ray_temp_root_dir=ray_temp_root_dir,
collect_log_to_path=collect_log_to_path,
)
cluster.wait_until_ready() # NB: this line might raise error.

# If connect cluster successfully, set global _active_ray_cluster to be the started
# cluster.
_active_ray_cluster = cluster
# If the cluster connects successfully, set the global
# _active_ray_cluster to the started cluster.
_active_ray_cluster = cluster
return cluster.address


@@ -1033,8 +1045,10 @@ def shutdown_ray_cluster() -> None:
Shut down the active ray cluster.
"""
global _active_ray_cluster
if _active_ray_cluster is None:
raise RuntimeError("No active ray cluster to shut down.")

_active_ray_cluster.shutdown()
_active_ray_cluster = None
with _active_ray_cluster_rwlock:
if _active_ray_cluster is None:
raise RuntimeError("No active ray cluster to shut down.")

_active_ray_cluster.shutdown()
_active_ray_cluster = None
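A note on the new _active_ray_cluster_rwlock: the auto-shutdown watcher added in databricks_hook.py (next file) acquires this lock and then calls shutdown_ray_cluster(), which acquires it again on the same thread. A plain threading.Lock would deadlock there, which is presumably why an RLock is used. A minimal sketch of that reentrancy, with the real bodies elided:

import threading

_active_ray_cluster_rwlock = threading.RLock()

def shutdown_ray_cluster():
    with _active_ray_cluster_rwlock:  # second acquisition, same thread
        pass  # shut down and clear the global cluster here

def auto_shutdown_watcher():
    with _active_ray_cluster_rwlock:  # first acquisition
        # Re-entering the lock below is safe only because it is reentrant.
        shutdown_ray_cluster()

auto_shutdown_watcher()  # completes without deadlocking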
99 changes: 93 additions & 6 deletions python/ray/util/spark/databricks_hook.py
@@ -1,6 +1,10 @@
import os

from .start_hook_base import RayOnSparkStartHook
from .utils import get_spark_session
import logging
import threading
import time

_logger = logging.getLogger(__name__)

@@ -54,21 +58,104 @@ def display_databricks_driver_proxy_url(spark_context, port, title):
)


DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS = 3
DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES = (
"DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES"
)


def _get_db_api_entry():
"""
Get databricks API entry point.
"""
return get_dbutils().entry_point


_DATABRICKS_DEFAULT_TMP_DIR = "/local_disk0/tmp"


class DefaultDatabricksRayOnSparkStartHook(RayOnSparkStartHook):
def get_default_temp_dir(self):
return "/local_disk0/tmp"
return _DATABRICKS_DEFAULT_TMP_DIR

def on_ray_dashboard_created(self, port):
display_databricks_driver_proxy_url(
get_spark_session().sparkContext, port, "Ray Cluster Dashboard"
)

def on_spark_background_job_created(self, job_group_id):
def on_cluster_created(self, ray_cluster_handler):
db_api_entry = _get_db_api_entry()
try:
get_dbutils().entry_point.registerBackgroundSparkJobGroup(job_group_id)
db_api_entry.registerBackgroundSparkJobGroup(
ray_cluster_handler.spark_job_group_id
)
except Exception:
_logger.warning(
"Register ray cluster spark job as background job failed. You need to "
"manually call `ray_cluster_on_spark.shutdown()` before detaching "
"your databricks python REPL."
"Registering Ray cluster spark job as background job failed. "
"You need to manually call `ray.util.spark.shutdown_ray_cluster()` "
"before detaching your databricks notebook."
)

auto_shutdown_minutes = float(
os.environ.get(DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES, "30")
)
if auto_shutdown_minutes == 0:
_logger.info(
"The Ray cluster will keep running until you manually detach the "
"databricks notebook or call "
"`ray.util.spark.shutdown_ray_cluster()`."
)
return
if auto_shutdown_minutes < 0:
raise ValueError(
"You must set "
f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' "
"to a value >= 0."
)

try:
db_api_entry.getIdleTimeMillisSinceLastNotebookExecution()
except Exception:
_logger.warning(
"Databricks `getIdleTimeMillisSinceLastNotebookExecution` API "
WeichenXu123 marked this conversation as resolved.
Show resolved Hide resolved
"is unavailable, it is probably because that "
"your current Databricks Runtime version does not support API "
"`getIdleTimeMillisSinceLastNotebookExecution`, we cannot "
WeichenXu123 marked this conversation as resolved.
Show resolved Hide resolved
"automatically shut down Ray cluster when databricks notebook "
"is inactive, you need to manually detach databricks notebook "
"or call `ray.util.spark.shutdown_ray_cluster()` to shut down "
"Ray cluster on spark."
)
return

_logger.info(
"The Ray cluster will be shut down automatically if you don't run "
"commands on the databricks notebook for "
f"{auto_shutdown_minutes} minutes. You can change the "
"automatically shutdown minutes by setting "
WeichenXu123 marked this conversation as resolved.
Show resolved Hide resolved
f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' environment "
"variable, setting it to 0 means that the Ray cluster keeps running "
"until you manually call `ray.util.spark.shutdown_ray_cluster()` or "
"detach databricks notebook."
)

def auto_shutdown_watcher():
auto_shutdown_millis = auto_shutdown_minutes * 60 * 1000
while True:
if ray_cluster_handler.is_shutdown:
# The cluster is shut down. The watcher thread exits.
return

idle_time = db_api_entry.getIdleTimeMillisSinceLastNotebookExecution()

if idle_time > auto_shutdown_millis:
from ray.util.spark import cluster_init

with cluster_init._active_ray_cluster_rwlock:
if ray_cluster_handler is cluster_init._active_ray_cluster:
cluster_init.shutdown_ray_cluster()
return

time.sleep(DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS)

threading.Thread(target=auto_shutdown_watcher, daemon=True).start()
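A consequence of the 3-second poll interval worth noting: shutdown fires within one poll interval after the idle threshold is crossed. A back-of-the-envelope check (not part of the PR) that also explains the 35-second sleep in test_databricks_hook.py:

poll_interval_s = 3          # DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS
auto_shutdown_minutes = 0.5  # value the test sets via the env var
worst_case_s = auto_shutdown_minutes * 60 + poll_interval_s
assert worst_case_s == 33.0  # the test sleeps 35 s, leaving a small margin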
2 changes: 1 addition & 1 deletion python/ray/util/spark/start_hook_base.py
@@ -5,5 +5,5 @@ def get_default_temp_dir(self):
def on_ray_dashboard_created(self, port):
pass

def on_spark_background_job_created(self, job_group):
def on_cluster_created(self, ray_cluster_handler):
pass
2 changes: 1 addition & 1 deletion python/requirements_test.txt
@@ -55,7 +55,7 @@ pygame==2.1.2; python_version < '3.11'
Pygments==2.13.0
pymongo==4.3.2
# TODO: Replace this with pyspark==3.4 once it is released.
https://ml-team-public-read.s3.us-west-2.amazonaws.com/pyspark-3.4.0.dev0.tar.gz
https://ml-team-public-read.s3.us-west-2.amazonaws.com/spark-pkgs/pyspark-3.4.0.dev0-0cb0fa313979e1b82ddd711a05d8c4e78cf6c9f5.tar.gz
pytest==7.0.1
pytest-asyncio==0.16.0
pytest-rerunfailures==10.2