From f008a3fbc8bb1d173b9a2883d7fae09e81ae4fa6 Mon Sep 17 00:00:00 2001
From: Alan Guo
Date: Wed, 5 Jun 2024 16:53:41 -0700
Subject: [PATCH] Add train_head.py with API to serve train runs (#45711)

Signed-off-by: Alan Guo
Signed-off-by: Richard Liu
---
 dashboard/modules/train/__init__.py        |  0
 dashboard/modules/train/train_head.py      | 89 ++++++++++++++++++++++
 python/ray/train/BUILD                     | 15 +++-
 python/ray/train/_internal/state/schema.py |  5 ++
 python/ray/train/tests/test_train_head.py  | 46 +++++++++++
 5 files changed, 154 insertions(+), 1 deletion(-)
 create mode 100644 dashboard/modules/train/__init__.py
 create mode 100644 dashboard/modules/train/train_head.py
 create mode 100644 python/ray/train/tests/test_train_head.py

diff --git a/dashboard/modules/train/__init__.py b/dashboard/modules/train/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/dashboard/modules/train/train_head.py b/dashboard/modules/train/train_head.py
new file mode 100644
index 000000000000..1747fe0580f7
--- /dev/null
+++ b/dashboard/modules/train/train_head.py
@@ -0,0 +1,89 @@
+import logging
+
+from aiohttp.web import Request, Response
+import ray
+from ray.util.annotations import DeveloperAPI
+
+import ray.dashboard.optional_utils as dashboard_optional_utils
+import ray.dashboard.utils as dashboard_utils
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+routes = dashboard_optional_utils.DashboardHeadRouteTable
+
+
+class TrainHead(dashboard_utils.DashboardHeadModule):
+    def __init__(self, dashboard_head):
+        super().__init__(dashboard_head)
+        self._train_stats_actor = None
+
+    @routes.get("/api/train/runs")
+    @dashboard_optional_utils.init_ray_and_catch_exceptions()
+    @DeveloperAPI
+    async def get_train_runs(self, req: Request) -> Response:
+        try:
+            from ray.train._internal.state.schema import (
+                TrainRunsResponse,
+            )
+        except ImportError:
+            logger.exception(
+                "Train is not installed. Please run `pip install ray[train]` "
+                "when setting up Ray on your cluster."
+            )
+            return Response(
+                status=500,
+                text="Train is not installed. Please run `pip install ray[train]` "
+                "when setting up Ray on your cluster.",
+            )
+
+        stats_actor = await self.get_train_stats_actor()
+
+        if stats_actor is None:
+            details = TrainRunsResponse(train_runs=[])
+        else:
+            try:
+                train_runs = await stats_actor.get_all_train_runs.remote()
+                # TODO(aguo): Sort by created_at
+                details = TrainRunsResponse(train_runs=list(train_runs.values()))
+            except ray.exceptions.RayTaskError as e:
+                # Task failures are sometimes caused by GCS
+                # failure. When the GCS fails, we expect it to take
+                # longer to recover.
+                return Response(
+                    status=503,
+                    text=(
+                        "Failed to get a response from the train stats actor. "
+                        f"The GCS may be down, please retry later: {e}"
+                    ),
+                )
+
+        return Response(
+            text=details.json(),
+            content_type="application/json",
+        )
+
+    @staticmethod
+    def is_minimal_module():
+        return False
+
+    async def run(self, server):
+        pass
+
+    async def get_train_stats_actor(self):
+        """
+        Gets the train stats actor and caches it as an instance variable.
+        """
+        try:
+            from ray.train._internal.state.state_actor import get_state_actor
+
+            if self._train_stats_actor is None:
+                self._train_stats_actor = get_state_actor()
+
+            return self._train_stats_actor
+        except ImportError:
+            logger.exception(
+                "Train is not installed. Please run `pip install ray[train]` "
+                "when setting up Ray on your cluster."
+            )
+            return None
diff --git a/python/ray/train/BUILD b/python/ray/train/BUILD
index 4235d57b4ab2..65bad8397c9e 100644
--- a/python/ray/train/BUILD
+++ b/python/ray/train/BUILD
@@ -658,7 +658,7 @@ py_test(
     deps = [":train_lib", ":conftest"],
 )
 
-### E2E Data + Train
+### E2E Data + Train
 py_test(
     name = "test_datasets_train",
     size = "medium",
@@ -667,6 +667,19 @@
     args = ["--smoke-test", "--num-workers=2", "--use-gpu"]
 )
 
+### Train Dashboard
+py_test(
+    name = "test_train_head",
+    size = "small",
+    srcs = ["tests/test_train_head.py"],
+    tags = [
+        "exclusive",
+        "ray_air",
+        "team:ml",
+    ],
+    deps = [":train_lib", ":conftest"],
+)
+
 # This is a dummy test dependency that causes the above tests to be
 # re-run if any of these files changes.
 py_library(
diff --git a/python/ray/train/_internal/state/schema.py b/python/ray/train/_internal/state/schema.py
index 63567bf5d6a7..426888e8bc5c 100644
--- a/python/ray/train/_internal/state/schema.py
+++ b/python/ray/train/_internal/state/schema.py
@@ -45,3 +45,8 @@ class TrainRunInfo(BaseModel):
     datasets: List[TrainDatasetInfo] = Field(
         description="A List of dataset info for this Train run."
     )
+
+
+@DeveloperAPI
+class TrainRunsResponse(BaseModel):
+    train_runs: List[TrainRunInfo]
diff --git a/python/ray/train/tests/test_train_head.py b/python/ray/train/tests/test_train_head.py
new file mode 100644
index 000000000000..6cb1e13d4880
--- /dev/null
+++ b/python/ray/train/tests/test_train_head.py
@@ -0,0 +1,46 @@
+import os
+import sys
+import time
+
+import pytest
+import requests
+
+import ray
+from ray.train import RunConfig, ScalingConfig
+from ray.train.torch import TorchTrainer
+
+
+def test_get_train_runs(monkeypatch, shutdown_only):
+    monkeypatch.setenv("RAY_TRAIN_ENABLE_STATE_TRACKING", "1")
+
+    ray.init(num_cpus=8)
+
+    def train_func():
+        print("Training Starts")
+        time.sleep(0.5)
+
+    datasets = {"train": ray.data.range(100), "val": ray.data.range(100)}
+
+    trainer = TorchTrainer(
+        train_func,
+        run_config=RunConfig(name="my_train_run", storage_path="/tmp/cluster_storage"),
+        scaling_config=ScalingConfig(num_workers=4, use_gpu=False),
+        datasets=datasets,
+    )
+    trainer.fit()
+
+    # Call the train run API
+    url = ray._private.worker.get_dashboard_url()
+    resp = requests.get("http://" + url + "/api/train/runs")
+    assert resp.status_code == 200
+    body = resp.json()
+    assert len(body["train_runs"]) == 1
+    assert body["train_runs"][0]["name"] == "my_train_run"
+    assert len(body["train_runs"][0]["workers"]) == 4
+
+
+if __name__ == "__main__":
+    if os.environ.get("PARALLEL_CI"):
+        sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
+    else:
+        sys.exit(pytest.main(["-sv", __file__]))
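
For anyone trying the new endpoint by hand rather than through the test above, a minimal
sketch (not part of the patch) of querying /api/train/runs from a separate script. It
assumes the dashboard of the cluster that ran training is reachable at the default
address, and that a Train run was started with RAY_TRAIN_ENABLE_STATE_TRACKING=1:

    import requests

    # Assumption: the Ray dashboard is at the default address; substitute your own.
    DASHBOARD_URL = "http://127.0.0.1:8265"

    resp = requests.get(f"{DASHBOARD_URL}/api/train/runs")
    resp.raise_for_status()

    # The JSON body matches TrainRunsResponse from
    # python/ray/train/_internal/state/schema.py.
    for run in resp.json()["train_runs"]:
        print(run["name"], "workers:", len(run["workers"]))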