From 1f73347f75939961fd45b58830aec6dcf563ecdb Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 31 Jan 2023 19:36:55 +0800 Subject: [PATCH] [spark] Automatically shut down ray on spark cluster if user does not execute commands on databricks notebook for a long time (#31962) Databricks Runtime provides an API, dbutils.entry_point.getIdleTimeMillisSinceLastNotebookExecution(), that returns the number of milliseconds elapsed since the last Databricks notebook code execution. This PR calls that API to monitor notebook activity and shuts down the Ray cluster once the notebook has been idle for longer than the timeout configured via the DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES environment variable (default: 30 minutes; 0 disables auto-shutdown). Signed-off-by: Weichen Xu --- python/ray/tests/BUILD | 1 + python/ray/tests/spark/test_basic.py | 16 ++- .../ray/tests/spark/test_databricks_hook.py | 80 +++++++++++++++ python/ray/util/spark/cluster_init.py | 70 +++++++------ python/ray/util/spark/databricks_hook.py | 99 +++++++++++++++++-- python/ray/util/spark/start_hook_base.py | 2 +- python/requirements_test.txt | 2 +- 7 files changed, 233 insertions(+), 37 deletions(-) create mode 100644 python/ray/tests/spark/test_databricks_hook.py diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 62ea4bfeffec..5fa316f00fec 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -535,6 +535,7 @@ py_test_module_list( "spark/test_GPU.py", "spark/test_multicores_per_task.py", "spark/test_utils.py", + "spark/test_databricks_hook.py", ], size = "large", tags = ["exclusive", "spark_plugin_tests", "team:serverless"], diff --git a/python/ray/tests/spark/test_basic.py b/python/ray/tests/spark/test_basic.py index dac18bab1604..325fb3e4d96a 100644 --- a/python/ray/tests/spark/test_basic.py +++ b/python/ray/tests/spark/test_basic.py @@ -46,6 +46,7 @@ class RayOnSparkCPUClusterTestBase(ABC): @classmethod def teardown_class(cls): time.sleep(10) # Wait all background spark job canceled. 
+ os.environ.pop("SPARK_WORKER_CORES", None) cls.spark.stop() @staticmethod @@ -70,6 +71,11 @@ def test_cpu_allocation(self): self.num_cpus_per_spark_task * 2, MAX_NUM_WORKER_NODES, ), + ( + self.max_spark_tasks // 2, + self.num_cpus_per_spark_task * 2, + self.max_spark_tasks // 2 + 1, + ), # Test case: requesting resources exceeding all cluster resources ]: with _setup_ray_cluster( num_worker_nodes=num_worker_nodes_arg, @@ -92,6 +98,12 @@ def test_public_api(self): ray_temp_root_dir=ray_temp_root_dir, head_node_options={"include_dashboard": True}, ) + + assert ( + os.environ["RAY_ADDRESS"] + == ray.util.spark.cluster_init._active_ray_cluster.address + ) + ray.init() @ray.remote @@ -104,6 +116,8 @@ def f(x): shutdown_ray_cluster() + assert "RAY_ADDRESS" not in os.environ + time.sleep(7) # assert temp dir is removed. assert len(os.listdir(ray_temp_root_dir)) == 1 and os.listdir( @@ -121,7 +135,7 @@ def f(x): if ray.util.spark.cluster_init._active_ray_cluster is not None: # if the test raised error and does not destroy cluster, # destroy it here. 
- ray.util.spark._active_ray_cluster.shutdown() + ray.util.spark.cluster_init._active_ray_cluster.shutdown() time.sleep(5) shutil.rmtree(ray_temp_root_dir, ignore_errors=True) shutil.rmtree(collect_log_to_path, ignore_errors=True) diff --git a/python/ray/tests/spark/test_databricks_hook.py b/python/ray/tests/spark/test_databricks_hook.py new file mode 100644 index 000000000000..5d296ac4f804 --- /dev/null +++ b/python/ray/tests/spark/test_databricks_hook.py @@ -0,0 +1,80 @@ +import sys + +import pytest +import os +import time +import ray +from pyspark.sql import SparkSession +from ray.util.spark import setup_ray_cluster +import ray.util.spark.databricks_hook + + +pytestmark = pytest.mark.skipif( + not sys.platform.startswith("linux"), + reason="Ray on spark only supports running on Linux.", +) + + +class MockDbApiEntry: + def __init__(self): + self.created_time = time.time() + self.registered_job_groups = [] + + def getIdleTimeMillisSinceLastNotebookExecution(self): + return (time.time() - self.created_time) * 1000 + + def registerBackgroundSparkJobGroup(self, job_group_id): + self.registered_job_groups.append(job_group_id) + + +class TestDatabricksHook: + @classmethod + def setup_class(cls): + os.environ["SPARK_WORKER_CORES"] = "2" + cls.spark = ( + SparkSession.builder.master("local-cluster[1, 2, 1024]") + .config("spark.task.cpus", "1") + .config("spark.task.maxFailures", "1") + .config("spark.executorEnv.RAY_ON_SPARK_WORKER_CPU_CORES", "2") + .getOrCreate() + ) + + @classmethod + def teardown_class(cls): + time.sleep(10) # Wait all background spark job canceled. 
+ cls.spark.stop() + os.environ.pop("SPARK_WORKER_CORES") + + def test_hook(self, monkeypatch): + monkeypatch.setattr( + "ray.util.spark.databricks_hook._DATABRICKS_DEFAULT_TMP_DIR", "/tmp" + ) + monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "12.2") + monkeypatch.setenv("DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES", "0.5") + db_api_entry = MockDbApiEntry() + monkeypatch.setattr( + "ray.util.spark.databricks_hook._get_db_api_entry", lambda: db_api_entry + ) + try: + setup_ray_cluster( + num_worker_nodes=2, + head_node_options={"include_dashboard": False}, + ) + cluster = ray.util.spark.cluster_init._active_ray_cluster + assert not cluster.is_shutdown + assert db_api_entry.registered_job_groups == [cluster.spark_job_group_id] + time.sleep(35) + assert cluster.is_shutdown + assert ray.util.spark.cluster_init._active_ray_cluster is None + finally: + if ray.util.spark.cluster_init._active_ray_cluster is not None: + # if the test raised error and does not destroy cluster, + # destroy it here. + ray.util.spark.cluster_init._active_ray_cluster.shutdown() + + +if __name__ == "__main__": + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/util/spark/cluster_init.py b/python/ray/util/spark/cluster_init.py index c8f8b347f4bc..50c5ba4fa21f 100644 --- a/python/ray/util/spark/cluster_init.py +++ b/python/ray/util/spark/cluster_init.py @@ -100,11 +100,6 @@ def _cancel_background_spark_job(self): def wait_until_ready(self): import ray - if self.background_job_exception is not None: - raise RuntimeError( - "Ray workers has exited." - ) from self.background_job_exception - if self.is_shutdown: raise RuntimeError( "The ray cluster has been shut down or it failed to start." 
@@ -125,6 +120,16 @@ def wait_until_ready(self): last_progress_move_time = time.time() while True: time.sleep(_RAY_CLUSTER_STARTUP_PROGRESS_CHECKING_INTERVAL) + + # Inside the waiting ready loop, + # checking `self.background_job_exception`, if it is not None, + # it means the background spark job has failed, + # in this case, raise error directly. + if self.background_job_exception is not None: + raise RuntimeError( + "Ray workers have exited." + ) from self.background_job_exception + cur_alive_worker_count = ( len([node for node in ray.nodes() if node["Alive"]]) - 1 ) # Minus 1 means excluding the head node. @@ -144,12 +149,17 @@ def wait_until_ready(self): time.time() - last_progress_move_time > _RAY_CONNECT_CLUSTER_POLL_PROGRESS_TIMEOUT ): + if cur_alive_worker_count == 0: + raise RuntimeError( + "Current spark cluster has no resources to launch " + "Ray worker nodes." + ) _logger.warning( "Timeout in waiting for all ray workers to start. " "Started / Total requested: " f"({cur_alive_worker_count} / {self.num_worker_nodes}). " - "Please check ray logs to see why some ray workers " - "failed to start." + "Current spark cluster does not have sufficient resources " + "to launch requested number of Ray worker nodes." ) return finally: @@ -678,7 +688,7 @@ def background_job_thread_fn(): ).start() # Call hook immediately after spark job started. - start_hook.on_spark_background_job_created(spark_job_group_id) + start_hook.on_cluster_created(ray_cluster_handler) # wait background spark task starting. 
for _ in range(_BACKGROUND_JOB_STARTUP_WAIT): @@ -699,6 +709,7 @@ def background_job_thread_fn(): _active_ray_cluster = None +_active_ray_cluster_rwlock = threading.RLock() def _create_resource_profile(num_cpus_per_node, num_gpus_per_node): @@ -1007,23 +1018,24 @@ def setup_ray_cluster( else: _logger.warning("\n".join(insufficient_resources)) - cluster = _setup_ray_cluster( - num_worker_nodes=num_worker_nodes, - num_cpus_per_node=num_cpus_per_node, - num_gpus_per_node=num_gpus_per_node, - using_stage_scheduling=using_stage_scheduling, - heap_memory_per_node=ray_worker_node_heap_mem_bytes, - object_store_memory_per_node=ray_worker_node_object_store_mem_bytes, - head_node_options=head_node_options, - worker_node_options=worker_node_options, - ray_temp_root_dir=ray_temp_root_dir, - collect_log_to_path=collect_log_to_path, - ) - cluster.wait_until_ready() # NB: this line might raise error. + with _active_ray_cluster_rwlock: + cluster = _setup_ray_cluster( + num_worker_nodes=num_worker_nodes, + num_cpus_per_node=num_cpus_per_node, + num_gpus_per_node=num_gpus_per_node, + using_stage_scheduling=using_stage_scheduling, + heap_memory_per_node=ray_worker_node_heap_mem_bytes, + object_store_memory_per_node=ray_worker_node_object_store_mem_bytes, + head_node_options=head_node_options, + worker_node_options=worker_node_options, + ray_temp_root_dir=ray_temp_root_dir, + collect_log_to_path=collect_log_to_path, + ) + cluster.wait_until_ready() # NB: this line might raise error. - # If connect cluster successfully, set global _active_ray_cluster to be the started - # cluster. - _active_ray_cluster = cluster + # If connect cluster successfully, set global _active_ray_cluster to be the + # started cluster. + _active_ray_cluster = cluster return cluster.address @@ -1033,8 +1045,10 @@ def shutdown_ray_cluster() -> None: Shut down the active ray cluster. 
""" global _active_ray_cluster - if _active_ray_cluster is None: - raise RuntimeError("No active ray cluster to shut down.") - _active_ray_cluster.shutdown() - _active_ray_cluster = None + with _active_ray_cluster_rwlock: + if _active_ray_cluster is None: + raise RuntimeError("No active ray cluster to shut down.") + + _active_ray_cluster.shutdown() + _active_ray_cluster = None diff --git a/python/ray/util/spark/databricks_hook.py b/python/ray/util/spark/databricks_hook.py index b55c5b1b9db6..0d45baef499c 100644 --- a/python/ray/util/spark/databricks_hook.py +++ b/python/ray/util/spark/databricks_hook.py @@ -1,6 +1,10 @@ +import os + from .start_hook_base import RayOnSparkStartHook from .utils import get_spark_session import logging +import threading +import time _logger = logging.getLogger(__name__) @@ -54,21 +58,104 @@ def display_databricks_driver_proxy_url(spark_context, port, title): ) +DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS = 3 +DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES = ( + "DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES" +) + + +def _get_db_api_entry(): + """ + Get databricks API entry point. + """ + return get_dbutils().entry_point + + +_DATABRICKS_DEFAULT_TMP_DIR = "/local_disk0/tmp" + + class DefaultDatabricksRayOnSparkStartHook(RayOnSparkStartHook): def get_default_temp_dir(self): - return "/local_disk0/tmp" + return _DATABRICKS_DEFAULT_TMP_DIR def on_ray_dashboard_created(self, port): display_databricks_driver_proxy_url( get_spark_session().sparkContext, port, "Ray Cluster Dashboard" ) - def on_spark_background_job_created(self, job_group_id): + def on_cluster_created(self, ray_cluster_handler): + db_api_entry = _get_db_api_entry() try: - get_dbutils().entry_point.registerBackgroundSparkJobGroup(job_group_id) + db_api_entry.registerBackgroundSparkJobGroup( + ray_cluster_handler.spark_job_group_id + ) except Exception: _logger.warning( - "Register ray cluster spark job as background job failed. 
You need to " - "manually call `ray_cluster_on_spark.shutdown()` before detaching " - "your databricks python REPL." + "Registering Ray cluster spark job as background job failed. " + "You need to manually call `ray.util.spark.shutdown_ray_cluster()` " + "before detaching your databricks notebook." + ) + + auto_shutdown_minutes = float( + os.environ.get(DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES, "30") + ) + if auto_shutdown_minutes == 0: + _logger.info( + "The Ray cluster will keep running until you manually detach the " + "databricks notebook or call " + "`ray.util.spark.shutdown_ray_cluster()`." + ) + return + if auto_shutdown_minutes < 0: + raise ValueError( + "You must set " + f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' " + "to a value >= 0." ) + + try: + db_api_entry.getIdleTimeMillisSinceLastNotebookExecution() + except Exception: + _logger.warning( + "Databricks `getIdleTimeMillisSinceLastNotebookExecution` API " + "is unavailable, it is probably because that " + "your current Databricks Runtime version does not support API " + "`getIdleTimeMillisSinceLastNotebookExecution`, we cannot " + "automatically shut down Ray cluster when databricks notebook " + "is inactive, you need to manually detach databricks notebook " + "or call `ray.util.spark.shutdown_ray_cluster()` to shut down " + "Ray cluster on spark." + ) + return + + _logger.info( + "The Ray cluster will be shut down automatically if you don't run " + "commands on the databricks notebook for " + f"{auto_shutdown_minutes} minutes. You can change the " + "automatically shutdown minutes by setting " + f"'{DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_MINUTES}' environment " + "variable, setting it to 0 means that the Ray cluster keeps running " + "until you manually call `ray.util.spark.shutdown_ray_cluster()` or " + "detach databricks notebook." 
+ ) + + def auto_shutdown_watcher(): + auto_shutdown_millis = auto_shutdown_minutes * 60 * 1000 + while True: + if ray_cluster_handler.is_shutdown: + # The cluster is shut down. The watcher thread exits. + return + + idle_time = db_api_entry.getIdleTimeMillisSinceLastNotebookExecution() + + if idle_time > auto_shutdown_millis: + from ray.util.spark import cluster_init + + with cluster_init._active_ray_cluster_rwlock: + if ray_cluster_handler is cluster_init._active_ray_cluster: + cluster_init.shutdown_ray_cluster() + return + + time.sleep(DATABRICKS_AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS) + + threading.Thread(target=auto_shutdown_watcher, daemon=True).start() diff --git a/python/ray/util/spark/start_hook_base.py b/python/ray/util/spark/start_hook_base.py index 6421c0fc3c58..5bb750d3bdfa 100644 --- a/python/ray/util/spark/start_hook_base.py +++ b/python/ray/util/spark/start_hook_base.py @@ -5,5 +5,5 @@ def get_default_temp_dir(self): def on_ray_dashboard_created(self, port): pass - def on_spark_background_job_created(self, job_group): + def on_cluster_created(self, ray_cluster_handler): pass diff --git a/python/requirements_test.txt b/python/requirements_test.txt index 8e994f34e43c..ab643c3709fe 100644 --- a/python/requirements_test.txt +++ b/python/requirements_test.txt @@ -55,7 +55,7 @@ pygame==2.1.2; python_version < '3.11' Pygments==2.13.0 pymongo==4.3.2 # TODO: Replace this with pyspark==3.4 once it is released. -https://ml-team-public-read.s3.us-west-2.amazonaws.com/pyspark-3.4.0.dev0.tar.gz +https://ml-team-public-read.s3.us-west-2.amazonaws.com/spark-pkgs/pyspark-3.4.0.dev0-0cb0fa313979e1b82ddd711a05d8c4e78cf6c9f5.tar.gz pytest==7.0.1 pytest-asyncio==0.16.0 pytest-rerunfailures==10.2