[spark] Automatically shut down ray on spark cluster if user does not execute commands on databricks notebook for a long time #31962

Merged

Changes from 8 commits
1 change: 1 addition & 0 deletions python/ray/tests/BUILD
@@ -526,6 +526,7 @@ py_test_module_list(
         "spark/test_GPU.py",
         "spark/test_multicores_per_task.py",
         "spark/test_utils.py",
+        "spark/test_databricks_hook.py",
     ],
     size = "large",
     tags = ["exclusive", "spark_plugin_tests", "team:serverless"],
6 changes: 6 additions & 0 deletions python/ray/tests/spark/test_basic.py
@@ -92,6 +92,10 @@ def test_public_api(self):
                 ray_temp_root_dir=ray_temp_root_dir,
                 head_node_options={"include_dashboard": True},
             )
+
+            assert os.environ["RAY_ADDRESS"] == \
+                ray.util.spark.cluster_init._active_ray_cluster.address
+
             ray.init()
@@ -104,6 +108,8 @@ def f(x):
 
             shutdown_ray_cluster()
 
+            assert "RAY_ADDRESS" not in os.environ
+
             time.sleep(7)
             # assert temp dir is removed.
             assert len(os.listdir(ray_temp_root_dir)) == 1 and os.listdir(
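The two new assertions codify a small contract: setup_ray_cluster exports RAY_ADDRESS so a bare ray.init() connects to the on-spark cluster, and shutdown_ray_cluster removes it again. A minimal notebook-style usage sketch under that contract (the 10-minute value and the workload are illustrative, and the timeout variable only has effect on Databricks):

import os
import ray
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

# Illustrative: shut down after 10 idle minutes instead of the default 30.
# Only honored on Databricks; setting it to 0 disables the auto-shutdown
# watcher entirely. Must be set before setup_ray_cluster() is called.
os.environ["DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_TIMEOUT_MINUTES"] = "10"

setup_ray_cluster(num_worker_nodes=2)  # also sets os.environ["RAY_ADDRESS"]
ray.init()  # picks up RAY_ADDRESS and connects to the on-spark cluster

# ... run Ray workloads here ...

shutdown_ray_cluster()  # stops the cluster and removes RAY_ADDRESS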
81 changes: 81 additions & 0 deletions python/ray/tests/spark/test_databricks_hook.py
@@ -0,0 +1,81 @@
import sys

import mock
import pytest
import os
import time
import ray
from pyspark.sql import SparkSession
from ray.util.spark import setup_ray_cluster


pytestmark = pytest.mark.skipif(
    not sys.platform.startswith("linux"),
    reason="Ray on spark only supports running on Linux.",
)


class MockDbApiEntry:
    def __init__(self):
        # Pretend the notebook has been idle since the epoch, so the reported
        # idle time is enormous and the auto-shutdown watcher fires on its
        # first poll.
        self.idle_time = 0
        self.registered_job_groups = []

    def getIdleTimeMillisSinceLastNotebookExecution(self):
        return (time.time() - self.idle_time) * 1000

    def registerBackgroundSparkJobGroup(self, job_group_id):
        self.registered_job_groups.append(job_group_id)


class TestDatabricksHook:
    @classmethod
    def setup_class(cls):
        os.environ["SPARK_WORKER_CORES"] = "2"
        cls.spark = (
            SparkSession.builder.master("local-cluster[1, 2, 1024]")
            .config("spark.task.cpus", "1")
            .config("spark.task.maxFailures", "1")
            .config("spark.executorEnv.RAY_ON_SPARK_WORKER_CPU_CORES", "2")
            .getOrCreate()
        )

    @classmethod
    def teardown_class(cls):
        time.sleep(10)  # Wait until all background spark jobs are canceled.
        cls.spark.stop()

    def test_hook(self):
        try:
            db_api_entry = MockDbApiEntry()
            with mock.patch(
                "ray.util.spark.databricks_hook._get_db_api_entry",
                return_value=db_api_entry,
            ), mock.patch.dict(
                os.environ,
                {"DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_TIMEOUT_MINUTES": "2"},
            ):
                setup_ray_cluster(
                    num_worker_nodes=2,
                    head_node_options={"include_dashboard": False},
                )
                cluster = ray.util.spark.cluster_init._active_ray_cluster
                assert not cluster.is_shutdown
                assert db_api_entry.registered_job_groups == \
                    [cluster.spark_job_group_id]
                time.sleep(2.5)
                assert cluster.is_shutdown
                assert ray.util.spark.cluster_init._active_ray_cluster is None
        finally:
            if ray.util.spark.cluster_init._active_ray_cluster is not None:
                # If the test raised an error and did not destroy the cluster,
                # destroy it here.
                ray.util.spark.cluster_init._active_ray_cluster.shutdown()
                time.sleep(5)


if __name__ == "__main__":
    if os.environ.get("PARALLEL_CI"):
        sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
    else:
        sys.exit(pytest.main(["-sv", __file__]))
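A note on the timing in test_hook: because MockDbApiEntry reports idle time measured from the epoch, the watcher's very first poll already exceeds the 2-minute timeout, so shutdown begins as soon as the cluster lock is released after setup. A rough arithmetic sketch (values are illustrative):

import time

timeout_minutes = 2.0
timeout_millis = timeout_minutes * 60 * 1000       # 120_000 ms
reported_idle_millis = (time.time() - 0) * 1000    # ~1.7e12 ms since the epoch

# True on the watcher's first poll, which is why the test only needs
# time.sleep(2.5) (less than one 3-second poll interval) before asserting
# that the cluster has been shut down.
assert reported_idle_millis > timeout_millis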
55 changes: 32 additions & 23 deletions python/ray/util/spark/cluster_init.py
@@ -144,12 +144,17 @@ def wait_until_ready(self):
                     time.time() - last_progress_move_time
                     > _RAY_CONNECT_CLUSTER_POLL_PROGRESS_TIMEOUT
                 ):
+                    if cur_alive_worker_count == 0:
+                        raise RuntimeError(
+                            "Current spark cluster has no resources to launch "
+                            "Ray worker nodes."
+                        )
                     _logger.warning(
                         "Timeout in waiting for all ray workers to start. "
                         "Started / Total requested: "
                         f"({cur_alive_worker_count} / {self.num_worker_nodes}). "
-                        "Please check ray logs to see why some ray workers "
-                        "failed to start."
+                        "Current spark cluster does not have sufficient resources "
+                        "to launch the requested number of Ray worker nodes."
                     )
                     return
         finally:
@@ -678,7 +683,7 @@ def background_job_thread_fn():
         ).start()
 
         # Call hook immediately after spark job started.
-        start_hook.on_spark_background_job_created(spark_job_group_id)
+        start_hook.on_cluster_created(ray_cluster_handler)
 
         # wait background spark task starting.
         for _ in range(_BACKGROUND_JOB_STARTUP_WAIT):
@@ -699,6 +704,7 @@
 
 
 _active_ray_cluster = None
+_active_ray_cluster_rwlock = threading.RLock()
 
 
 def _create_resource_profile(num_cpus_per_node, num_gpus_per_node):
@@ -1007,23 +1013,24 @@ def setup_ray_cluster(
     else:
         _logger.warning("\n".join(insufficient_resources))
 
-    cluster = _setup_ray_cluster(
-        num_worker_nodes=num_worker_nodes,
-        num_cpus_per_node=num_cpus_per_node,
-        num_gpus_per_node=num_gpus_per_node,
-        using_stage_scheduling=using_stage_scheduling,
-        heap_memory_per_node=ray_worker_node_heap_mem_bytes,
-        object_store_memory_per_node=ray_worker_node_object_store_mem_bytes,
-        head_node_options=head_node_options,
-        worker_node_options=worker_node_options,
-        ray_temp_root_dir=ray_temp_root_dir,
-        collect_log_to_path=collect_log_to_path,
-    )
-    cluster.wait_until_ready()  # NB: this line might raise an error.
+    with _active_ray_cluster_rwlock:
+        cluster = _setup_ray_cluster(
+            num_worker_nodes=num_worker_nodes,
+            num_cpus_per_node=num_cpus_per_node,
+            num_gpus_per_node=num_gpus_per_node,
+            using_stage_scheduling=using_stage_scheduling,
+            heap_memory_per_node=ray_worker_node_heap_mem_bytes,
+            object_store_memory_per_node=ray_worker_node_object_store_mem_bytes,
+            head_node_options=head_node_options,
+            worker_node_options=worker_node_options,
+            ray_temp_root_dir=ray_temp_root_dir,
+            collect_log_to_path=collect_log_to_path,
+        )
+        cluster.wait_until_ready()  # NB: this line might raise an error.
 
-    # If connect cluster successfully, set global _active_ray_cluster to be the started
-    # cluster.
-    _active_ray_cluster = cluster
+        # If the cluster connects successfully, set the global
+        # _active_ray_cluster to the started cluster.
+        _active_ray_cluster = cluster
     return cluster.address
 
 
@@ -1033,8 +1040,10 @@ def shutdown_ray_cluster() -> None:
     Shut down the active ray cluster.
     """
     global _active_ray_cluster
-    if _active_ray_cluster is None:
-        raise RuntimeError("No active ray cluster to shut down.")
 
-    _active_ray_cluster.shutdown()
-    _active_ray_cluster = None
+    with _active_ray_cluster_rwlock:
+        if _active_ray_cluster is None:
+            raise RuntimeError("No active ray cluster to shut down.")
+
+        _active_ray_cluster.shutdown()
+        _active_ray_cluster = None
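Why the lock: the auto-shutdown watcher added in databricks_hook.py (next file) runs on a daemon thread and may call shutdown_ray_cluster() concurrently with user code calling setup_ray_cluster() or shutdown_ray_cluster(). A minimal sketch of the pattern, using the names from this diff (the helper function itself is illustrative, not part of the PR):

import threading

_active_ray_cluster = None
_active_ray_cluster_rwlock = threading.RLock()

def _shutdown_if_still_active(cluster):
    # Hypothetical helper mirroring what the watcher thread does: only shut
    # down if the handle it watches is still the globally active cluster, so
    # a user-triggered shutdown or re-setup cannot race with the watcher.
    global _active_ray_cluster
    with _active_ray_cluster_rwlock:
        if _active_ray_cluster is cluster:
            _active_ray_cluster.shutdown()
            _active_ray_cluster = None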
78 changes: 76 additions & 2 deletions python/ray/util/spark/databricks_hook.py
@@ -1,6 +1,10 @@
+import os
+
 from .start_hook_base import RayOnSparkStartHook
 from .utils import get_spark_session
 import logging
+import threading
+import time
 
 _logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -54,6 +58,19 @@ def display_databricks_driver_proxy_url(spark_context, port, title):
)


AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS = 3
DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_TIMEOUT_MINUTES = (
WeichenXu123 marked this conversation as resolved.
Show resolved Hide resolved
WeichenXu123 marked this conversation as resolved.
Show resolved Hide resolved
WeichenXu123 marked this conversation as resolved.
Show resolved Hide resolved
"DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_TIMEOUT_MINUTES"
)


def _get_db_api_entry():
"""
Get databricks API entry point.
"""
return get_dbutils().entry_point


class DefaultDatabricksRayOnSparkStartHook(RayOnSparkStartHook):
def get_default_temp_dir(self):
return "/local_disk0/tmp"
@@ -63,12 +80,69 @@ def on_ray_dashboard_created(self, port):
             get_spark_session().sparkContext, port, "Ray Cluster Dashboard"
         )
 
-    def on_spark_background_job_created(self, job_group_id):
+    def on_cluster_created(self, ray_cluster_handler):
+        db_api_entry = _get_db_api_entry()
         try:
-            get_dbutils().entry_point.registerBackgroundSparkJobGroup(job_group_id)
+            db_api_entry.registerBackgroundSparkJobGroup(
+                ray_cluster_handler.spark_job_group_id
+            )
         except Exception:
             _logger.warning(
                 "Register ray cluster spark job as background job failed. You need to "
                 "manually call `ray_cluster_on_spark.shutdown()` before detaching "
                 "your databricks python REPL."
             )
 
+        auto_shutdown_timeout_minutes = float(
+            os.environ.get(DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_TIMEOUT_MINUTES, "30")
+        )
+        if auto_shutdown_timeout_minutes == 0:
+            _logger.info(
+                "The Ray cluster will keep running until you manually detach the "
+                "databricks notebook or call "
+                "`ray.util.spark.shutdown_ray_cluster()`."
+            )
+        else:
+            _logger.info(
+                "The Ray cluster will be shut down automatically if you don't run "
+                "commands on the databricks notebook for "
+                f"{auto_shutdown_timeout_minutes} minutes. You can change the "
+                "timeout by setting the "
+                "'DATABRICKS_RAY_ON_SPARK_AUTOSHUTDOWN_TIMEOUT_MINUTES' environment "
+                "variable; setting it to 0 means an infinite timeout."
+            )
+
+        def auto_shutdown_watcher():
+            auto_shutdown_timeout_millis = auto_shutdown_timeout_minutes * 60 * 1000
+            while True:
+                if ray_cluster_handler.is_shutdown:
+                    # The cluster is shut down. The watcher thread exits.
+                    return
+
+                try:
+                    idle_time = (
+                        db_api_entry.getIdleTimeMillisSinceLastNotebookExecution()
+                    )
+                except Exception:
+                    _logger.warning(
+                        "Your current Databricks Runtime version does not support "
+                        "the API `getIdleTimeMillisSinceLastNotebookExecution`, so "
+                        "the Ray cluster cannot be shut down automatically when "
+                        "the databricks notebook is inactive. You need to manually "
+                        "detach the databricks notebook or call "
+                        "`ray.util.spark.shutdown_ray_cluster()` to shut down the "
+                        "Ray cluster on spark."
+                    )
+                    return
+
+                if idle_time > auto_shutdown_timeout_millis:
+                    from ray.util.spark import cluster_init
+
+                    with cluster_init._active_ray_cluster_rwlock:
+                        if ray_cluster_handler is cluster_init._active_ray_cluster:
+                            cluster_init.shutdown_ray_cluster()
+                    return
+
+                time.sleep(AUTO_SHUTDOWN_POLL_INTERVAL_SECONDS)
+
+        if auto_shutdown_timeout_minutes > 0:
+            threading.Thread(target=auto_shutdown_watcher, args=(), daemon=True).start()
2 changes: 1 addition & 1 deletion python/ray/util/spark/start_hook_base.py
@@ -5,5 +5,5 @@ def get_default_temp_dir(self):
     def on_ray_dashboard_created(self, port):
         pass
 
-    def on_spark_background_job_created(self, job_group):
+    def on_cluster_created(self, ray_cluster_handler):
         pass
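The base-class change above renames the callback and widens its argument from a spark job group id to the whole cluster handle. A hypothetical subclass, just to illustrate the new signature (the class name and body are invented, not part of this PR):

from ray.util.spark.start_hook_base import RayOnSparkStartHook

class LoggingStartHook(RayOnSparkStartHook):
    # Hypothetical example subclass illustrating the new callback.
    def on_cluster_created(self, ray_cluster_handler):
        # The handle exposes the spark job group id that used to be the
        # callback's only argument, plus cluster state such as is_shutdown.
        print(
            "Ray cluster created, spark job group:",
            ray_cluster_handler.spark_job_group_id,
        )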
2 changes: 1 addition & 1 deletion python/requirements_test.txt
@@ -54,7 +54,7 @@ pygame==2.1.2; python_version < '3.11'
 Pygments==2.13.0
 pymongo==4.3.2
 # TODO: Replace this with pyspark==3.4 once it is released.
-https://ml-team-public-read.s3.us-west-2.amazonaws.com/pyspark-3.4.0.dev0.tar.gz
+https://ml-team-public-read.s3.us-west-2.amazonaws.com/spark-pkgs/pyspark-3.4.0.dev0-0cb0fa313979e1b82ddd711a05d8c4e78cf6c9f5.tar.gz
 pytest==7.0.1
 pytest-asyncio==0.16.0
 pytest-rerunfailures==10.2