Skip to content

Commit

Permalink
Add train_head.py with API to serve train runs (#45711)
Browse files Browse the repository at this point in the history

Signed-off-by: Alan Guo <[email protected]>
  • Loading branch information
alanwguo authored Jun 5, 2024
1 parent e233555 commit dd15641
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 1 deletion.
Empty file.
89 changes: 89 additions & 0 deletions dashboard/modules/train/train_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import logging

from aiohttp.web import Request, Response
import ray
from ray.util.annotations import DeveloperAPI

import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils

# Module-level logger for this dashboard module; its threshold is set to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Route table whose decorators (``routes.get`` below) are applied to the
# HTTP handler methods of this module.
routes = dashboard_optional_utils.DashboardHeadRouteTable


class TrainHead(dashboard_utils.DashboardHeadModule):
    """Dashboard head module that serves Ray Train run state over HTTP."""

    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)
        # Handle to the Train state actor; filled in lazily by
        # get_train_stats_actor().
        self._train_stats_actor = None

    @staticmethod
    def is_minimal_module():
        """Report that this module is not a minimal module."""
        return False

    async def run(self, server):
        """Do nothing."""
        pass

    async def get_train_stats_actor(self):
        """Return the Train state actor, caching it on the instance.

        Returns:
            The previously cached actor if there is one; otherwise the result
            of ``get_state_actor()``, which is stored before being returned.
            ``None`` when an ``ImportError`` is raised while resolving it.
        """
        try:
            from ray.train._internal.state.state_actor import get_state_actor

            actor = self._train_stats_actor
            if actor is None:
                actor = get_state_actor()
                self._train_stats_actor = actor
            return actor
        except ImportError:
            logger.exception(
                "Train is not installed. Please run `pip install ray[train]` "
                "when setting up Ray on your cluster."
            )
            return None

    @routes.get("/api/train/runs")
    @dashboard_optional_utils.init_ray_and_catch_exceptions()
    @DeveloperAPI
    async def get_train_runs(self, req: Request) -> Response:
        """Handle ``GET /api/train/runs``.

        Args:
            req: The incoming HTTP request (not inspected by this handler).

        Returns:
            A JSON response holding a serialized ``TrainRunsResponse``. The
            run list is empty when no state actor is available. A 500
            response is returned when the Train schema cannot be imported,
            and a 503 response when querying the actor raises
            ``RayTaskError``.
        """
        try:
            from ray.train._internal.state.schema import TrainRunsResponse
        except ImportError:
            not_installed_msg = (
                "Train is not installed. Please run `pip install ray[train]` "
                "when setting up Ray on your cluster."
            )
            logger.exception(not_installed_msg)
            return Response(status=500, text=not_installed_msg)

        actor = await self.get_train_stats_actor()

        # Without a state actor there is nothing to report: answer with an
        # empty run list.
        if actor is None:
            empty = TrainRunsResponse(train_runs=[])
            return Response(text=empty.json(), content_type="application/json")

        try:
            runs_by_key = await actor.get_all_train_runs.remote()
            # TODO(aguo): Sort by created_at
            payload = TrainRunsResponse(train_runs=list(runs_by_key.values()))
        except ray.exceptions.RayTaskError as err:
            # Task failures are sometimes caused by a GCS failure, and the
            # GCS is expected to take longer to recover, so ask the client
            # to retry later.
            unavailable_msg = (
                "Failed to get a response from the train stats actor. "
                f"The GCS may be down, please retry later: {err}"
            )
            return Response(status=503, text=unavailable_msg)

        return Response(text=payload.json(), content_type="application/json")
15 changes: 14 additions & 1 deletion python/ray/train/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ py_test(
deps = [":train_lib", ":conftest"],
)

### E2E Data + Train
### E2E Data + Train
py_test(
name = "test_datasets_train",
size = "medium",
Expand All @@ -667,6 +667,19 @@ py_test(
args = ["--smoke-test", "--num-workers=2", "--use-gpu"]
)

### Train Dashboard
# Runs tests/test_train_head.py, which exercises the dashboard's
# /api/train/runs endpoint.
py_test(
    name = "test_train_head",
    size = "small",
    srcs = ["tests/test_train_head.py"],
    tags = [
        "exclusive",
        "ray_air",
        "team:ml",
    ],
    deps = [":train_lib", ":conftest"],
)

# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
py_library(
Expand Down
5 changes: 5 additions & 0 deletions python/ray/train/_internal/state/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,8 @@ class TrainRunInfo(BaseModel):
datasets: List[TrainDatasetInfo] = Field(
description="A List of dataset info for this Train run."
)


@DeveloperAPI
class TrainRunsResponse(BaseModel):
    """Response payload that wraps a list of Train runs."""

    # Required field; documented with ``Field(description=...)`` to match the
    # fields of ``TrainRunInfo`` in this module.
    train_runs: List[TrainRunInfo] = Field(
        description="A List of Train run info, one entry per Train run."
    )
46 changes: 46 additions & 0 deletions python/ray/train/tests/test_train_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
import sys
import time

import pytest
import requests

import ray
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer


def test_get_train_runs(monkeypatch, shutdown_only):
    """Check that ``/api/train/runs`` reports a completed Train run.

    Enables Train state tracking through the environment, runs a small
    ``TorchTrainer`` job with four workers, then queries the dashboard
    endpoint and checks the run name and worker count in the response.
    """
    monkeypatch.setenv("RAY_TRAIN_ENABLE_STATE_TRACKING", "1")

    ray.init(num_cpus=8)

    def train_func():
        print("Training Starts")
        time.sleep(0.5)

    datasets = {"train": ray.data.range(100), "val": ray.data.range(100)}

    trainer = TorchTrainer(
        train_func,
        run_config=RunConfig(name="my_train_run", storage_path="/tmp/cluster_storage"),
        scaling_config=ScalingConfig(num_workers=4, use_gpu=False),
        datasets=datasets,
    )
    trainer.fit()

    # Call the train run api. A timeout is set because requests waits
    # indefinitely by default; an unresponsive dashboard should fail the
    # test rather than hang it.
    url = ray._private.worker.get_dashboard_url()
    resp = requests.get("http://" + url + "/api/train/runs", timeout=30)
    # Include the body in the failure message so error responses are
    # visible in the test output.
    assert resp.status_code == 200, resp.text
    body = resp.json()
    assert len(body["train_runs"]) == 1
    assert body["train_runs"][0]["name"] == "my_train_run"
    assert len(body["train_runs"][0]["workers"]) == 4


if __name__ == "__main__":
    # When PARALLEL_CI is set in the environment, pass pytest the
    # ``-n auto --boxed`` options; otherwise pass only ``-sv``.
    pytest_args = (
        ["-n", "auto", "--boxed", "-vs", __file__]
        if os.environ.get("PARALLEL_CI")
        else ["-sv", __file__]
    )
    sys.exit(pytest.main(pytest_args))

0 comments on commit dd15641

Please sign in to comment.