diff --git a/python/ray/train/BUILD b/python/ray/train/BUILD
index 6820db6e43c5..4235d57b4ab2 100644
--- a/python/ray/train/BUILD
+++ b/python/ray/train/BUILD
@@ -501,6 +501,14 @@ py_test(
     deps = [":train_lib", ":conftest"]
 )
 
+py_test(
+    name = "test_state",
+    size = "small",
+    srcs = ["tests/test_state.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":train_lib", ":conftest"]
+)
+
 py_test(
     name = "test_tensorflow_checkpoint",
     size = "small",
diff --git a/python/ray/train/_internal/backend_executor.py b/python/ray/train/_internal/backend_executor.py
index cf6210c1bf4d..30ef5c657b3e 100644
--- a/python/ray/train/_internal/backend_executor.py
+++ b/python/ray/train/_internal/backend_executor.py
@@ -25,6 +25,7 @@
     ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
     ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
     ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
+    RAY_TRAIN_ENABLE_STATE_TRACKING,
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
     TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
 )
@@ -118,6 +119,8 @@ def __init__(
             )
         ]
 
+        self.state_tracking_enabled = env_integer(RAY_TRAIN_ENABLE_STATE_TRACKING, 0)
+
     def start(
         self,
         initialization_hook: Optional[Callable[[], None]] = None,
@@ -194,6 +197,12 @@ def _set_driver_dataset_context(ctx: DataContext):
             self._increment_failures()
             self._restart()
 
+        if self.state_tracking_enabled:
+            from ray.train._internal.state import TrainRunStateManager
+            from ray.train._internal.state.state_actor import get_state_actor
+
+            self.state_manager = TrainRunStateManager(state_actor=get_state_actor())
+
     def _create_placement_group(self):
         """Creates a placement group if it does not exist.
 
@@ -432,7 +441,6 @@ def start_training(
         data_config: DataConfig,
         storage: StorageContext,
         checkpoint: Optional[Checkpoint] = None,
-        on_session_init: Callable[[], None] = None,
     ) -> None:
         """Executes a training function on all workers in a separate thread.
 
@@ -528,8 +536,18 @@ def initialize_session(
 
         self.get_with_failure_handling(futures)
 
-        if on_session_init:
-            on_session_init()
+        # Register Train Run before training starts
+        if self.state_tracking_enabled:
+            core_context = ray.runtime_context.get_runtime_context()
+
+            self.state_manager.register_train_run(
+                run_id=self._trial_info.run_id,
+                run_name=self._trial_info.experiment_name,
+                job_id=core_context.get_job_id(),
+                controller_actor_id=core_context.get_actor_id(),
+                datasets=datasets,
+                worker_group=self.worker_group,
+            )
 
         # Run the training function asynchronously in its own thread.
         def train_async():
diff --git a/python/ray/train/_internal/session.py b/python/ray/train/_internal/session.py
index 8201e1f59dfa..8ad87c65314d 100644
--- a/python/ray/train/_internal/session.py
+++ b/python/ray/train/_internal/session.py
@@ -61,6 +61,7 @@ class TrialInfo:
     logdir: str
     driver_ip: str
    experiment_name: Optional[str] = None
+    run_id: Optional[str] = None
 
 
 class _FutureTrainingResult:
@@ -461,6 +462,10 @@ def trial_name(self) -> str:
     def trial_id(self) -> str:
         return self.trial_info.id
 
+    @property
+    def run_id(self) -> str:
+        return self.trial_info.run_id
+
     @property
     def trial_resources(self) -> "PlacementGroupFactory":
         return self.trial_info.resources
@@ -818,6 +823,13 @@ def get_trial_id() -> str:
     return _get_session().trial_id
 
 
+@PublicAPI(stability="alpha")
+@_warn_session_misuse()
+def get_run_id() -> str:
+    """The unique Train run ID for the corresponding trial."""
+    return _get_session().run_id
+
+
 @PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_trial_resources() -> "PlacementGroupFactory":
diff --git a/python/ray/train/_internal/state/__init__.py b/python/ray/train/_internal/state/__init__.py
new file mode 100644
index 000000000000..604a4fa39329
--- /dev/null
+++ b/python/ray/train/_internal/state/__init__.py
@@ -0,0 +1,14 @@
+from ray.train._internal.state.state_manager import TrainRunStateManager
+
+try:
+    import pydantic  # noqa: F401
+except ImportError:
+    raise ModuleNotFoundError(
+        "pydantic isn't installed. "
+        "To install pydantic, please run 'pip install pydantic'."
+    )
+
+
+__all__ = [
+    "TrainRunStateManager",
+]
diff --git a/python/ray/train/_internal/state/schema.py b/python/ray/train/_internal/state/schema.py
new file mode 100644
index 000000000000..63567bf5d6a7
--- /dev/null
+++ b/python/ray/train/_internal/state/schema.py
@@ -0,0 +1,47 @@
+from typing import List, Optional
+
+from ray._private.pydantic_compat import BaseModel, Field
+from ray.util.annotations import DeveloperAPI
+
+
+@DeveloperAPI
+class TrainWorkerInfo(BaseModel):
+    """Metadata of a Ray Train worker."""
+
+    actor_id: str = Field(description="Actor ID of the worker.")
+    world_rank: int = Field(description="World rank of the worker.")
+    local_rank: int = Field(description="Local rank of the worker.")
+    node_rank: int = Field(description="Node rank of the worker.")
+    node_id: str = Field(description="ID of the node that the worker is running on.")
+    node_ip: str = Field(
+        description="IP address of the node that the worker is running on."
+    )
+    pid: int = Field(description="Process ID of the worker.")
+    gpu_ids: List[int] = Field(
+        description="A list of GPU IDs allocated to the worker."
+    )
+
+
+@DeveloperAPI
+class TrainDatasetInfo(BaseModel):
+    name: str = Field(
+        description="The key of the dataset dict specified in the Ray Train Trainer."
+    )
+    dataset_uuid: str = Field(description="The UUID of the dataset.")
+    dataset_name: Optional[str] = Field(description="The name of the dataset.")
+
+
+@DeveloperAPI
+class TrainRunInfo(BaseModel):
+    """Metadata for a Ray Train run and information about its workers."""
+
+    name: str = Field(description="The name of the Train run.")
+    id: str = Field(description="The unique identifier for each Train run.")
+    job_id: str = Field(description="The Ray Job ID.")
+    controller_actor_id: str = Field(description="Actor ID of the Train controller.")
+    workers: List[TrainWorkerInfo] = Field(
+        description="A list of Train workers sorted by world rank."
+    )
+    datasets: List[TrainDatasetInfo] = Field(
+        description="A list of dataset info for this Train run."
+    )
diff --git a/python/ray/train/_internal/state/state_actor.py b/python/ray/train/_internal/state/state_actor.py
new file mode 100644
index 000000000000..7efa7aa9a578
--- /dev/null
+++ b/python/ray/train/_internal/state/state_actor.py
@@ -0,0 +1,60 @@
+import logging
+import threading
+from typing import Dict, Optional
+
+import ray
+from ray.actor import ActorHandle
+from ray.train._internal.state.schema import TrainRunInfo
+
+logger = logging.getLogger(__name__)
+
+
+@ray.remote(num_cpus=0)
+class TrainStateActor:
+    def __init__(self):
+        self._run_infos: Dict[str, TrainRunInfo] = {}
+
+    def register_train_run(self, run_info: TrainRunInfo) -> None:
+        # Register a new Train run.
+        self._run_infos[run_info.id] = run_info
+
+    def get_train_run(self, run_id: str) -> Optional[TrainRunInfo]:
+        # Retrieve a registered run by its id.
+        return self._run_infos.get(run_id, None)
+
+    def get_all_train_runs(self) -> Dict[str, TrainRunInfo]:
+        # Retrieve all registered Train runs.
+        return self._run_infos
+
+
+TRAIN_STATE_ACTOR_NAME = "train_state_actor"
+TRAIN_STATE_ACTOR_NAMESPACE = "_train_state_actor"
+
+_state_actor_lock: threading.RLock = threading.RLock()
+
+
+def get_or_create_state_actor() -> ActorHandle:
+    """Get or create a `TrainStateActor` on the head node."""
+    with _state_actor_lock:
+        state_actor = TrainStateActor.options(
+            name=TRAIN_STATE_ACTOR_NAME,
+            namespace=TRAIN_STATE_ACTOR_NAMESPACE,
+            get_if_exists=True,
+            lifetime="detached",
+            resources={"node:__internal_head__": 0.001},
+        ).remote()
+
+    # Ensure the state actor is ready before returning it.
+    ray.get(state_actor.__ray_ready__.remote())
+    return state_actor
+
+
+def get_state_actor() -> Optional[ActorHandle]:
+    """Get the `TrainStateActor` if it exists, otherwise return None."""
+    try:
+        return ray.get_actor(
+            name=TRAIN_STATE_ACTOR_NAME,
+            namespace=TRAIN_STATE_ACTOR_NAMESPACE,
+        )
+    except ValueError:
+        return None
diff --git a/python/ray/train/_internal/state/state_manager.py b/python/ray/train/_internal/state/state_manager.py
new file mode 100644
index 000000000000..66d08dfcae64
--- /dev/null
+++ b/python/ray/train/_internal/state/state_manager.py
@@ -0,0 +1,93 @@
+import logging
+import os
+from typing import Dict
+
+import ray
+from ray.data import Dataset
+from ray.train._internal.state.schema import (
+    TrainDatasetInfo,
+    TrainRunInfo,
+    TrainWorkerInfo,
+)
+from ray.train._internal.utils import check_for_failure
+from ray.train._internal.worker_group import WorkerGroup
+
+logger = logging.getLogger(__name__)
+
+
+class TrainRunStateManager:
+    """A class that aggregates and reports Train run info to the TrainStateActor.
+
+    This manager class is created on the train controller layer for each run.
+    """
+
+    def __init__(self, state_actor) -> None:
+        self.state_actor = state_actor
+
+    def register_train_run(
+        self,
+        run_id: str,
+        job_id: str,
+        run_name: str,
+        controller_actor_id: str,
+        datasets: Dict[str, Dataset],
+        worker_group: WorkerGroup,
+    ) -> None:
+        """Collect Train run info and report it to the TrainStateActor."""
+
+        if not self.state_actor:
+            logger.warning(
+                "Unable to register train run since `TrainStateActor` is not started."
+            )
+            return
+
+        def collect_train_worker_info():
+            train_context = ray.train.get_context()
+            core_context = ray.runtime_context.get_runtime_context()
+
+            return TrainWorkerInfo(
+                world_rank=train_context.get_world_rank(),
+                local_rank=train_context.get_local_rank(),
+                node_rank=train_context.get_node_rank(),
+                actor_id=core_context.get_actor_id(),
+                node_id=core_context.get_node_id(),
+                node_ip=ray.util.get_node_ip_address(),
+                gpu_ids=ray.get_gpu_ids(),
+                pid=os.getpid(),
+            )
+
+        futures = [
+            worker_group.execute_single_async(index, collect_train_worker_info)
+            for index in range(len(worker_group))
+        ]
+        success, exception = check_for_failure(futures)
+
+        if not success:
+            logger.error(
+                "Failed to collect run information from the Ray Train "
+                f"workers:\n{exception}"
+            )
+            return
+
+        worker_info_list = ray.get(futures)
+        worker_info_list = sorted(worker_info_list, key=lambda info: info.world_rank)
+
+        dataset_info_list = [
+            TrainDatasetInfo(
+                name=ds_name,
+                dataset_name=ds._plan._dataset_name,
+                dataset_uuid=ds._plan._dataset_uuid,
+            )
+            for ds_name, ds in datasets.items()
+        ]
+
+        train_run_info = TrainRunInfo(
+            id=run_id,
+            job_id=job_id,
+            name=run_name,
+            controller_actor_id=controller_actor_id,
+            workers=worker_info_list,
+            datasets=dataset_info_list,
+        )
+
+        ray.get(self.state_actor.register_train_run.remote(train_run_info))
diff --git a/python/ray/train/constants.py b/python/ray/train/constants.py
index f9ab73acc32d..530de9d3a2d5 100644
--- a/python/ray/train/constants.py
+++ b/python/ray/train/constants.py
@@ -78,6 +78,10 @@ def _get_ray_train_session_dir() -> str:
 # Defaults to 0, which always retries on node preemption failures.
 RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE = "RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE"
 
+# Set this to 1 to start a StateActor and collect information about Train runs.
+# Defaults to 0.
+RAY_TRAIN_ENABLE_STATE_TRACKING = "RAY_TRAIN_ENABLE_STATE_TRACKING"
+
 # NOTE: When adding a new environment variable, please track it in this list.
 TRAIN_ENV_VARS = {
     ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
@@ -87,6 +91,7 @@ def _get_ray_train_session_dir() -> str:
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
     RAY_CHDIR_TO_TRIAL_DIR,
     RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE,
+    RAY_TRAIN_ENABLE_STATE_TRACKING,
 }
 
 # Key for AIR Checkpoint metadata in TrainingResult metadata
diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py
index 182ac8e84da1..0c47e4409553 100644
--- a/python/ray/train/data_parallel_trainer.py
+++ b/python/ray/train/data_parallel_trainer.py
@@ -1,7 +1,9 @@
 import logging
+import uuid
 from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 import ray
+from ray._private.ray_constants import env_integer
 from ray._private.thirdparty.tabulate.tabulate import tabulate
 from ray.air.config import RunConfig, ScalingConfig
 from ray.train import BackendConfig, Checkpoint, TrainingIterator
@@ -10,6 +12,7 @@
 from ray.train._internal.data_config import DataConfig
 from ray.train._internal.session import _TrainingResult, get_session
 from ray.train._internal.utils import construct_train_func, count_required_parameters
+from ray.train.constants import RAY_TRAIN_ENABLE_STATE_TRACKING
 from ray.train.trainer import BaseTrainer, GenDataset
 from ray.util.annotations import DeveloperAPI, PublicAPI
 from ray.widgets import Template
@@ -265,6 +268,11 @@ def __init__(
             train_total_resources.get("GPU", 0),
         )
 
+        if env_integer(RAY_TRAIN_ENABLE_STATE_TRACKING, 0):
+            from ray.train._internal.state.state_actor import get_or_create_state_actor
+
+            get_or_create_state_actor()
+
     @PublicAPI(stability="beta")
     @classmethod
     def restore(
@@ -435,6 +443,7 @@ def training_loop(self) -> None:
             logdir=session.get_trial_dir(),
             driver_ip=ray.util.get_node_ip_address(),
             experiment_name=session.get_experiment_name(),
+            run_id=uuid.uuid4().hex,
         )
 
         backend_executor = self._backend_executor_cls(
diff --git a/python/ray/train/tests/test_state.py b/python/ray/train/tests/test_state.py
new file mode 100644
index 000000000000..9f1a2e81c46b
--- /dev/null
+++ b/python/ray/train/tests/test_state.py
@@ -0,0 +1,276 @@
+import json
+import os
+
+import pytest
+
+import ray
+from ray.cluster_utils import Cluster
+from ray.train import RunConfig, ScalingConfig
+from ray.train._internal.state.schema import (
+    TrainDatasetInfo,
+    TrainRunInfo,
+    TrainWorkerInfo,
+)
+from ray.train._internal.state.state_actor import (
+    TRAIN_STATE_ACTOR_NAME,
+    TRAIN_STATE_ACTOR_NAMESPACE,
+    get_or_create_state_actor,
+)
+from ray.train._internal.state.state_manager import TrainRunStateManager
+from ray.train._internal.worker_group import WorkerGroup
+from ray.train.data_parallel_trainer import DataParallelTrainer
+
+
+@pytest.fixture
+def ray_start_gpu_cluster():
+    cluster = Cluster()
+    cluster.add_node(num_gpus=8, num_cpus=9)
+
+    ray.shutdown()
+    ray.init(
+        address=cluster.address,
+        runtime_env={"env_vars": {"RAY_TRAIN_ENABLE_STATE_TRACKING": "1"}},
+        ignore_reinit_error=True,
+    )
+
+    yield
+
+    ray.shutdown()
+    cluster.shutdown()
+
+
+RUN_INFO_JSON_SAMPLE = """{
+    "name": "default_run",
+    "id": "ad5256bc64c04c83833a8b006f531799",
+    "job_id": "0000000001",
+    "controller_actor_id": "3abd1972a19148d78acc78dd9414736e",
+    "workers": [
+        {
+            "actor_id": "3d86c25634a71832dac32c8802000000",
+            "world_rank": 0,
+            "local_rank": 0,
+            "node_rank": 0,
+            "node_id": "b1e6cbed8533ae2def4e7e7ced9d19858ceb1ed8ab9ba81ab9c07825",
+            "node_ip": "10.0.208.100",
+            "pid": 76071,
+            "gpu_ids": [0]
+        },
+        {
+            "actor_id": "8f162dd8365346d1b5c98ebd7338c4f9",
+            "world_rank": 1,
"local_rank": 1, + "node_rank": 0, + "node_id": "b1e6cbed8533ae2def4e7e7ced9d19858ceb1ed8ab9ba81ab9c07825", + "node_ip": "10.0.208.100", + "pid": 76072, + "gpu_ids": [1] + } + ], + "datasets": [ + { + "name": "train", + "dataset_name": "train_dataset", + "dataset_uuid": "1" + } + ] +}""" + + +def _get_run_info_sample(run_id=None, run_name=None) -> TrainRunInfo: + dataset_info = TrainDatasetInfo( + name="train", dataset_name="train_dataset", dataset_uuid="1" + ) + + worker_info_0 = TrainWorkerInfo( + actor_id="3d86c25634a71832dac32c8802000000", + world_rank=0, + local_rank=0, + node_rank=0, + node_id="b1e6cbed8533ae2def4e7e7ced9d19858ceb1ed8ab9ba81ab9c07825", + node_ip="10.0.208.100", + pid=76071, + gpu_ids=[0], + ) + + worker_info_1 = TrainWorkerInfo( + actor_id="8f162dd8365346d1b5c98ebd7338c4f9", + world_rank=1, + local_rank=1, + node_rank=0, + node_id="b1e6cbed8533ae2def4e7e7ced9d19858ceb1ed8ab9ba81ab9c07825", + node_ip="10.0.208.100", + pid=76072, + gpu_ids=[1], + ) + + run_info = TrainRunInfo( + name=run_name if run_name else "default_run", + id=run_id if run_id else "ad5256bc64c04c83833a8b006f531799", + job_id="0000000001", + controller_actor_id="3abd1972a19148d78acc78dd9414736e", + workers=[worker_info_0, worker_info_1], + datasets=[dataset_info], + ) + return run_info + + +def test_schema_equivalance(): + json_sample = RUN_INFO_JSON_SAMPLE + dict_sample = json.loads(RUN_INFO_JSON_SAMPLE) + + run_info_from_json = TrainRunInfo.parse_raw(json_sample) + run_info_from_obj = TrainRunInfo.parse_obj(dict_sample) + + # Test serialization equivalence + assert run_info_from_json == run_info_from_obj + + # Test dict deserialization equivalence + assert run_info_from_json.dict() == dict_sample + + # Test json deserialization equivalence + assert json.loads(run_info_from_json.json()) == json.loads(json_sample) + + # Test constructors equivalence + assert _get_run_info_sample() == run_info_from_json + + +def test_state_actor_api(): + state_actor = get_or_create_state_actor() + named_actors = ray.util.list_named_actors(all_namespaces=True) + assert { + "name": TRAIN_STATE_ACTOR_NAME, + "namespace": TRAIN_STATE_ACTOR_NAMESPACE, + } in named_actors + + # Concurrently register 100 runs + num_runs = 100 + info_list = [_get_run_info_sample(run_id=str(i)) for i in range(num_runs)] + ray.get([state_actor.register_train_run.remote(run) for run in info_list]) + + # Test get all runs + train_runs = ray.get(state_actor.get_all_train_runs.remote()) + assert len(train_runs) == num_runs + + # Test get a single run by run_id + for i in range(num_runs): + run_info = ray.get(state_actor.get_train_run.remote(run_id=str(i))) + assert run_info == info_list[i] + + +def test_state_manager(ray_start_gpu_cluster): + worker_group = WorkerGroup(num_workers=4, resources_per_worker={"GPU": 1}) + + # No errors raised if TrainStateActor is not started + state_manager = TrainRunStateManager(state_actor=None) + state_manager.register_train_run( + run_id="run_id", + run_name="run_name", + job_id="0000000001", + controller_actor_id="3abd1972a19148d78acc78dd9414736e", + datasets={}, + worker_group=worker_group, + ) + + # Register 100 runs with 10 TrainRunStateManagers + state_actor = get_or_create_state_actor() + for i in range(10): + state_manager = TrainRunStateManager(state_actor=state_actor) + for j in range(10): + run_id = i * 10 + j + state_manager.register_train_run( + run_id=str(run_id), + run_name="run_name", + job_id="0000000001", + controller_actor_id="3abd1972a19148d78acc78dd9414736e", + datasets={ + "train": 
+                    "eval": ray.data.from_items(list(range(4))),
+                },
+                worker_group=worker_group,
+            )
+
+    runs = ray.get(state_actor.get_all_train_runs.remote())
+    assert len(runs) == 100
+
+    for i in range(100):
+        run_id = str(i)
+        run_info = ray.get(state_actor.get_train_run.remote(run_id=run_id))
+        assert run_info and run_info.id == run_id
+
+
+@pytest.mark.parametrize("gpus_per_worker", [0, 1, 2])
+def test_track_e2e_training(ray_start_gpu_cluster, gpus_per_worker):
+    os.environ["RAY_TRAIN_ENABLE_STATE_TRACKING"] = "1"
+    num_workers = 4
+    run_name = "test"
+    datasets = {
+        "train": ray.data.from_items(list(range(4))),
+        "eval": ray.data.from_items(list(range(4))),
+    }
+
+    if gpus_per_worker == 0:
+        use_gpu = False
+        resources_per_worker = {"CPU": 1}
+    else:
+        use_gpu = True
+        resources_per_worker = {"GPU": gpus_per_worker}
+
+    trainer = DataParallelTrainer(
+        train_loop_per_worker=lambda: None,
+        run_config=RunConfig(name=run_name),
+        scaling_config=ScalingConfig(
+            num_workers=num_workers,
+            use_gpu=use_gpu,
+            resources_per_worker=resources_per_worker,
+        ),
+        datasets=datasets,
+    )
+
+    trainer.fit()
+
+    state_actor = ray.get_actor(
+        name=TRAIN_STATE_ACTOR_NAME, namespace=TRAIN_STATE_ACTOR_NAMESPACE
+    )
+
+    runs = ray.get(state_actor.get_all_train_runs.remote())
+    run_id = next(iter(runs.keys()))
+    run = next(iter(runs.values()))
+
+    # Check Run Info
+    assert run.id == run_id
+    assert run.name == run_name
+    assert len(run.workers) == num_workers
+    assert run.controller_actor_id and run.job_id
+
+    world_ranks = [worker.world_rank for worker in run.workers]
+    local_ranks = [worker.local_rank for worker in run.workers]
+    node_ranks = [worker.node_rank for worker in run.workers]
+
+    # Ensure that the workers are sorted by global rank
+    assert world_ranks == [0, 1, 2, 3]
+    assert local_ranks == [0, 1, 2, 3]
+    assert node_ranks == [0, 0, 0, 0]
+
+    # Check GPU ids
+    gpu_ids = [worker.gpu_ids for worker in run.workers]
+    if gpus_per_worker == 0:
+        assert gpu_ids == [[], [], [], []]
+    elif gpus_per_worker == 1:
+        assert gpu_ids == [[0], [1], [2], [3]]
+    elif gpus_per_worker == 2:
+        flat_gpu_ids = set()
+        for ids in gpu_ids:
+            flat_gpu_ids.update(ids)
+        assert flat_gpu_ids == set(range(8))
+
+    # Check Datasets
+    for dataset_info in run.datasets:
+        dataset = datasets[dataset_info.name]
+        assert dataset_info.dataset_name == dataset._plan._dataset_name
+        assert dataset_info.dataset_uuid == dataset._plan._dataset_uuid
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(["-v", "-x", __file__]))
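
Usage sketch (illustrative, not part of the diff): with the changes above, state tracking is opt-in via the RAY_TRAIN_ENABLE_STATE_TRACKING environment variable, and registered runs can be read back from the detached TrainStateActor. The snippet below is a minimal example assuming a local Ray cluster; the run name, dataset, and worker count are arbitrary placeholders, and it only uses names introduced or exercised in this diff (get_state_actor, get_all_train_runs, DataParallelTrainer).

import os

# Assumption: the env var is set in the driver before the trainer is
# constructed, mirroring how the tests above enable tracking.
os.environ["RAY_TRAIN_ENABLE_STATE_TRACKING"] = "1"

import ray
from ray.train import RunConfig, ScalingConfig
from ray.train._internal.state.state_actor import get_state_actor
from ray.train.data_parallel_trainer import DataParallelTrainer

ray.init()

trainer = DataParallelTrainer(
    train_loop_per_worker=lambda: None,  # no-op training loop, as in the tests above
    run_config=RunConfig(name="state_tracking_demo"),  # hypothetical run name
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": ray.data.from_items(list(range(4)))},
)
trainer.fit()

# The controller registered one TrainRunInfo with the detached TrainStateActor.
state_actor = get_state_actor()
runs = ray.get(state_actor.get_all_train_runs.remote())
for run_id, run_info in runs.items():
    print(run_id, run_info.name, [w.world_rank for w in run_info.workers])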