
[Train][Observability] Track Train Run Info with TrainStateActor #44585

Merged
merged 29 commits into from
Apr 26, 2024
Changes from 12 commits
Commits
29 commits
9997c2c
add lightning version restrictions
woshiyyya Apr 8, 2024
5141441
init
woshiyyya Apr 9, 2024
3f48f50
fix
woshiyyya Apr 9, 2024
4a7657e
update
woshiyyya Apr 10, 2024
d9e69cb
start state actor
woshiyyya Apr 15, 2024
dfe6e9d
fix dataset id
woshiyyya Apr 15, 2024
7c87556
add dataset schema
woshiyyya Apr 15, 2024
ce168b2
Merge remote-tracking branch 'upstream/master' into train/add_train_r…
woshiyyya Apr 15, 2024
da9e812
change run_id to trial_info
woshiyyya Apr 15, 2024
3c868e2
add trial run id property
woshiyyya Apr 15, 2024
5a86c59
add TrainStatManager
woshiyyya Apr 15, 2024
e044776
update TrainRunStatsManager
woshiyyya Apr 16, 2024
7359531
Merge branch 'master' into train/add_train_run_schema
woshiyyya Apr 16, 2024
c1ca02a
fix import
woshiyyya Apr 16, 2024
bab8c0b
fix ci
woshiyyya Apr 16, 2024
ead3f62
fix ci
woshiyyya Apr 16, 2024
7a27dc9
Apply suggestions from code review
woshiyyya Apr 17, 2024
787dc09
update code structure
woshiyyya Apr 17, 2024
8fad340
fix typo
woshiyyya Apr 17, 2024
a9e47f7
fix lint
woshiyyya Apr 17, 2024
86b40f9
address comments
woshiyyya Apr 22, 2024
b92d05f
fix
woshiyyya Apr 22, 2024
5713d65
fix circular import
woshiyyya Apr 22, 2024
ac9b42b
add test & launch stateactor on driver
woshiyyya Apr 24, 2024
22a0606
Merge branch 'master' into train/add_train_run_schema
woshiyyya Apr 24, 2024
218713c
update tests
woshiyyya Apr 24, 2024
843d748
add state manager test
woshiyyya Apr 25, 2024
c627925
address comments
woshiyyya Apr 26, 2024
486ac2c
clean ut import
woshiyyya Apr 26, 2024
25 changes: 22 additions & 3 deletions python/ray/train/_internal/backend_executor.py
@@ -4,6 +4,8 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar

from ray.train._internal.stats import TrainRunStatsManager

import ray
import ray._private.ray_constants as ray_constants
from ray._private.ray_constants import env_integer
@@ -23,6 +25,7 @@
from ray.train.backend import BackendConfig
from ray.train.constants import (
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
ENABLE_RAY_TRAIN_DASHBOARD_ENV,
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
TRAIN_ENABLE_WORKER_SPREAD_ENV,
@@ -118,6 +121,8 @@ def __init__(
)
]

self.dashboard_enabled = env_integer(ENABLE_RAY_TRAIN_DASHBOARD_ENV, 0)

def start(
self,
initialization_hook: Optional[Callable[[], None]] = None,
@@ -194,6 +199,10 @@ def _set_driver_dataset_context(ctx: DataContext):
self._increment_failures()
self._restart()

# Set up TrainRunStatsManager for the Ray Train Dashboard
if self.dashboard_enabled:
self.stats_manager = TrainRunStatsManager()

def _create_placement_group(self):
"""Creates a placement group if it does not exist.

@@ -432,7 +441,6 @@ def start_training(
data_config: DataConfig,
storage: StorageContext,
checkpoint: Optional[Checkpoint] = None,
on_session_init: Callable[[], None] = None,
) -> None:
"""Executes a training function on all workers in a separate thread.

@@ -528,8 +536,19 @@ def initialize_session(

self.get_with_failure_handling(futures)

if on_session_init:
on_session_init()
# Register Train Run before training starts
if self.dashboard_enabled:
session = get_session()
trainer_actor_id = ray.runtime_context.get_runtime_context().get_actor_id()

self.stats_manager.register_train_run(
run_id=session.run_id,
run_name=session.experiment_name,
trial_name=session.trial_name,
trainer_actor_id=trainer_actor_id,
datasets=datasets,
worker_group=self.worker_group,
)

# Run the training function asynchronously in its own thread.
def train_async():
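Taken together, these backend_executor changes create the stats manager and register the run only when the dashboard env var is set before the trainer starts. A minimal opt-in sketch; the TorchTrainer usage below is illustrative and not part of this diff:

import os

# The backend executor reads this via env_integer(ENABLE_RAY_TRAIN_DASHBOARD_ENV, 0),
# so it must be set before the trainer is constructed.
os.environ["ENABLE_RAY_TRAIN_DASHBOARD_ENV"] = "1"

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_func():
    # Normal training loop; register_train_run() runs in the backend executor
    # right before this function starts on the workers.
    ...


trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=2))
trainer.fit()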
63 changes: 63 additions & 0 deletions python/ray/train/_internal/schema.py
@@ -0,0 +1,63 @@
from typing import List, Optional

from ray._private.pydantic_compat import BaseModel, Field
from ray.util.annotations import DeveloperAPI

try:
import pydantic # noqa: F401
except ImportError:
raise ModuleNotFoundError(
"pydantic isn't installed. "
"To install pydantic, please run 'pip install pydantic'"
)


@DeveloperAPI
class TrainWorkerInfo(BaseModel):
"""Metadata of a Ray Train worker."""

actor_id: str = Field(description="Actor ID of the worker.")
world_rank: int = Field(description="World rank.")
local_rank: int = Field(description="Local rank.")
node_rank: int = Field(description="Node rank.")
gpu_ids: Optional[List[str]] = Field(
description="A list of GPU ids allocated to that worker."
)
node_id: Optional[str] = Field(
description="ID of the node that the worker is running on."
)
node_ip: Optional[str] = Field(
description="IP address of the node that the worker is running on."
)
pid: Optional[str] = Field(description="PID of the worker.")


@DeveloperAPI
class TrainDatasetInfo(BaseModel):
name: str = Field(
description="The key of the dataset dict specified in Ray Train Trainer."
)
plan_name: str = Field(description="The name of the internal dataset plan.")
plan_uuid: str = Field(description="The uuid of the internal dataset plan.")
Contributor:

Why do we need the plan name if we have the UUID?

@woshiyyya (Member, Author) on Apr 17, 2024:

According to the observability team, dataset_tag is the primary key they use in the Dashboard.

Also, the Data team is using dataset_tag = f"{plan_name}_{plan_uuid}" as a unique id for a dataset.

metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)

@woshiyyya (Member, Author) on Apr 17, 2024:

Separating plan_name and plan_uuid also makes it easy to support stream_split_op in the future.

def _get_dataset_tag(self):
return create_dataset_tag(
self._base_dataset._plan._dataset_name,
self._base_dataset._uuid,
self._output_split_idx,
)

Contributor:

Should we ask Ray Data team for a better way to get the tag?
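For reference, the tag discussed in this thread could be rebuilt on the dashboard side from the persisted schema fields; the helper below is hypothetical and not part of this PR:

def dataset_tag(info) -> str:
    # Mirrors the Data team's dataset_tag = f"{plan_name}_{plan_uuid}" convention
    # using the plan_name/plan_uuid stored in TrainDatasetInfo.
    return f"{info.plan_name}_{info.plan_uuid}"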



@DeveloperAPI
class TrainRunInfo(BaseModel):
"""Metadata for a Ray Train run and information about its workers."""

name: str = Field(description="The name of the Train run.")
id: str = Field(description="The unique identifier for each Train run.")
job_id: str = Field(description="Ray Job ID.")
trial_name: str = Field(
description=(
"Trial name. It should be different among different Train runs, "
"except for those that are restored from checkpoints."
)
)
trainer_actor_id: str = Field(description="Actor Id of the Trainer.")
workers: List[TrainWorkerInfo] = Field(
description="A List of Train workers sorted by global ranks."
)
Contributor:

Consideration for elastic training: workers might get added and removed due to failures.

@woshiyyya (Member, Author) on Apr 17, 2024:

Right. We can expose an API in TrainStatsActor to update the TrainRunInfo during training.

For schema, we can add a dead_workers entry to track those evicted workers.
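A rough sketch of that follow-up, assuming a dead_workers field is later added to TrainRunInfo; neither the field nor the update method exists in this PR:

import ray


@ray.remote(num_cpus=0)
class TrainStatsActor:
    def __init__(self):
        self.train_runs = dict()

    def register_train_run(self, run_info):
        self.train_runs[run_info.id] = run_info

    def update_train_run_workers(self, run_id, workers, dead_workers):
        # Hypothetical: refresh worker membership after an elastic resize or
        # failure recovery, keeping evicted workers visible to the dashboard.
        run_info = self.train_runs[run_id]
        self.train_runs[run_id] = run_info.copy(
            update={"workers": workers, "dead_workers": dead_workers}
        )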

datasets: List[TrainDatasetInfo] = Field(
description="A List of dataset info for this Train run."
)
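Since the new schema is plain pydantic, a run record can be constructed and serialized without any Ray machinery. A minimal sketch with placeholder values:

from ray.train._internal.schema import TrainDatasetInfo, TrainRunInfo, TrainWorkerInfo

# All values below are placeholders; real runs fill them from the runtime context.
worker = TrainWorkerInfo(
    actor_id="abc123",
    world_rank=0,
    local_rank=0,
    node_rank=0,
    gpu_ids=["0"],
    node_id="node-1",
    node_ip="10.0.0.1",
    pid="4242",
)
dataset = TrainDatasetInfo(name="train", plan_name="ReadParquet", plan_uuid="uuid-1")
run = TrainRunInfo(
    name="my_experiment",
    id="run-uuid",
    job_id="job-1",
    trial_name="TorchTrainer_00000",
    trainer_actor_id="def456",
    workers=[worker],
    datasets=[dataset],
)
print(run.json())  # JSON payload stored by the stats actor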
21 changes: 18 additions & 3 deletions python/ray/train/_internal/session.py
@@ -9,7 +9,7 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Type

import ray
from ray.air._internal.session import _get_session
@@ -60,7 +60,9 @@ class TrialInfo:
resources: Dict[str, float]
logdir: str
driver_ip: str
run_id: str
experiment_name: Optional[str] = None
datasets_info: Optional[List[Dict[str, Any]]] = None


class _FutureTrainingResult:
@@ -461,6 +463,10 @@ def trial_name(self) -> str:
def trial_id(self) -> str:
return self.trial_info.id

@property
def run_id(self) -> str:
return self.trial_info.run_id

@property
def trial_resources(self) -> "PlacementGroupFactory":
return self.trial_info.resources
@@ -478,7 +484,8 @@ def get_dataset_shard(
warnings.warn(
"No dataset passed in. Returning None. Make sure to "
"pass in a Dataset to Trainer.run to use this "
"function."
"function.",
stacklevel=2,
)
elif isinstance(shard, dict):
if not dataset_name:
@@ -646,7 +653,8 @@ def wrapper(*args, **kwargs):
warnings.warn(
f"`{fn_name}` is meant to only be "
"called inside a function that is executed by a Tuner"
f" or Trainer. Returning `{default_value}`."
f" or Trainer. Returning `{default_value}`.",
stacklevel=2,
)
return default_value
return fn(*args, **kwargs)
@@ -818,6 +826,13 @@ def get_trial_id() -> str:
return _get_session().trial_id


@PublicAPI(stability="alpha")
@_warn_session_misuse()
def get_run_id() -> str:
"""Unique Train Run id for the corresponding trial."""
return _get_session().run_id


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_resources() -> "PlacementGroupFactory":
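A minimal sketch of calling the new accessor from a training function; it is imported from the internal session module here because a top-level re-export is not shown in this diff:

from ray.train._internal.session import get_run_id


def train_func():
    # Returns the run's unique id inside a Train session; outside a session
    # the _warn_session_misuse wrapper warns and returns the default value.
    run_id = get_run_id()
    print(f"train run id: {run_id}")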
117 changes: 117 additions & 0 deletions python/ray/train/_internal/stats.py
@@ -0,0 +1,117 @@
import logging
import os
import threading
from typing import Dict, Optional

import ray
from ray.data import Dataset
from ray.train._internal.schema import TrainDatasetInfo, TrainRunInfo, TrainWorkerInfo
from ray.train._internal.utils import check_for_failure
from ray.train._internal.worker_group import WorkerGroup

logger = logging.getLogger(__name__)


@ray.remote(num_cpus=0)
class TrainStatsActor:
Contributor:

Reading over this, I realized there are no stats, and this points to a bunch of Train...Info objects. I wonder if there is a word that we can use instead of Stats/Info that captures/unifies these two.

Member (Author):

Yes, currently there's only static metadata stored in this actor, and in the future there will be stats.

What about State?

def __init__(self):
self.train_runs = dict()

def register_train_run(self, run_info: TrainRunInfo):
# Register a new train run.
self.train_runs[run_info.id] = run_info

def get_train_run(self, run_id: str) -> Optional[TrainRunInfo]:
# Retrieve a registered run with its id
return self.train_runs.get(run_id, None)

def get_all_train_runs(self) -> Dict[str, TrainRunInfo]:
# Retrieve all registered train runs
return self.train_runs


TRAIN_STATS_ACTOR_NAME = "train_stats_actor"
TRAIN_STATS_ACTOR_NAMESPACE = "_train_stats_actor"

_stats_actor_lock: threading.RLock = threading.RLock()


def get_or_launch_stats_actor():
"""Create or launch a `TrainStatsActor` on the head node."""
with _stats_actor_lock:
return TrainStatsActor.options(
name=TRAIN_STATS_ACTOR_NAME,
namespace=TRAIN_STATS_ACTOR_NAMESPACE,
get_if_exists=True,
lifetime="detached",
resources={"node:__internal_head__": 0.001},
).remote()


class TrainRunStatsManager:
"""A class that aggregates and reports train run info to TrainStatsActor.

This manager class is created on the train controller layer for each run.
"""

def __init__(self) -> None:
self.stats_actor = get_or_launch_stats_actor()

def register_train_run(
@woshiyyya (Member, Author) on Apr 16, 2024:

Trying to make sure all the pydantic-related code is encapsulated in TrainRunStatsManager, so that OSS users who are not using ray[default] will not get an error.

self,
run_id: str,
run_name: str,
trial_name: str,
trainer_actor_id: str,
datasets: Dict[str, Dataset],
worker_group: WorkerGroup,
Contributor:

I do think it feels a bit weird to pass the WorkerGroup here, but not sure if there is another cleaner way to organize it.

@woshiyyya (Member, Author) on Apr 22, 2024:

The consideration here is that for elastic/fault-tolerant training, passing the worker group through the function arguments lets us avoid accidentally using a stale WorkerGroup.

) -> None:
"""Collect Train Run Info and report to StatsActor."""

def collect_train_worker_info():
train_context = ray.train.get_context()
core_context = ray.runtime_context.get_runtime_context()

return TrainWorkerInfo(
world_rank=train_context.get_world_rank(),
local_rank=train_context.get_local_rank(),
node_rank=train_context.get_node_rank(),
actor_id=core_context.get_actor_id(),
node_id=core_context.get_node_id(),
node_ip=core_context.get_node_ip_address(),
gpu_ids=core_context.get_accelerator_ids().get("GPU", []),
pid=os.getpid(),
)

futures = [
worker_group.execute_single_async(index, collect_train_worker_info)
for index in range(len(worker_group))
]
success, exception = check_for_failure(futures)

if not success:
logger.warning("Failed to collect infomation for Ray Train Worker.")
return

worker_info_list = ray.get(futures)
worker_info_list = sorted(worker_info_list, key=lambda info: info.world_rank)

dataset_info_list = [
TrainDatasetInfo(
name=ds_name,
plan_name=ds._plan._dataset_name,
plan_uuid=ds._plan._dataset_uuid,
)
for ds_name, ds in datasets.items()
]

train_run_info = TrainRunInfo(
id=run_id,
name=run_name,
trial_name=trial_name,
trainer_actor_id=trainer_actor_id,
workers=worker_info_list,
datasets=dataset_info_list,
)

self.stats_actor.register_train_run.remote(train_run_info)
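Once a run is registered, any process in the cluster (for example a dashboard backend) can read it back through the detached actor's well-known name and namespace. A minimal sketch, assuming a run with the dashboard env var enabled has already launched the actor:

import ray
from ray.train._internal.stats import (
    TRAIN_STATS_ACTOR_NAME,
    TRAIN_STATS_ACTOR_NAMESPACE,
)

stats_actor = ray.get_actor(
    TRAIN_STATS_ACTOR_NAME, namespace=TRAIN_STATS_ACTOR_NAMESPACE
)
runs = ray.get(stats_actor.get_all_train_runs.remote())
for run_id, run_info in runs.items():
    print(run_id, run_info.name, len(run_info.workers))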
3 changes: 2 additions & 1 deletion python/ray/train/base_trainer.py
@@ -361,7 +361,8 @@ def training_loop(self):
f"Invalid trainer type. You are attempting to restore a trainer of type"
f" {trainer_cls} with `{cls.__name__}.restore`, "
"which will most likely fail. "
f"Use `{trainer_cls.__name__}.restore` instead."
f"Use `{trainer_cls.__name__}.restore` instead.",
stacklevel=2,
)

original_datasets = param_dict.pop("datasets", {})
5 changes: 5 additions & 0 deletions python/ray/train/constants.py
@@ -78,6 +78,10 @@ def _get_ray_train_session_dir() -> str:
# Defaults to 0, which always retries on node preemption failures.
RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE = "RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE"

# Set this to 1 to start a StatsActor and collect information for the Ray Train Dashboard.
# Defaults to 0
ENABLE_RAY_TRAIN_DASHBOARD_ENV = "ENABLE_RAY_TRAIN_DASHBOARD_ENV"

# NOTE: When adding a new environment variable, please track it in this list.
TRAIN_ENV_VARS = {
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
@@ -87,6 +91,7 @@ def _get_ray_train_session_dir() -> str:
TRAIN_ENABLE_WORKER_SPREAD_ENV,
RAY_CHDIR_TO_TRIAL_DIR,
RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE,
ENABLE_RAY_TRAIN_DASHBOARD_ENV,
}

# Key for AIR Checkpoint metadata in TrainingResult metadata
2 changes: 2 additions & 0 deletions python/ray/train/data_parallel_trainer.py
@@ -431,10 +431,12 @@ def training_loop(self) -> None:
trial_info = TrialInfo(
name=session.get_trial_name(),
id=session.get_trial_id(),
run_id=session.get_run_id(),
resources=session.get_trial_resources(),
logdir=session.get_trial_dir(),
driver_ip=ray.util.get_node_ip_address(),
experiment_name=session.get_experiment_name(),
datasets_info=self.datasets_info,
)

backend_executor = self._backend_executor_cls(
2 changes: 2 additions & 0 deletions python/ray/tune/trainable/function_trainable.py
@@ -1,6 +1,7 @@
import inspect
import logging
import os
import uuid
from functools import partial
from numbers import Number
from typing import Any, Callable, Dict, Optional, Type
@@ -58,6 +59,7 @@ def setup(self, config):
logdir=self._storage.trial_driver_staging_path,
driver_ip=None,
experiment_name=self._storage.experiment_dir_name,
run_id=uuid.uuid4().hex,
@woshiyyya (Member, Author) on Apr 16, 2024:

Create a unique ID that differentiates each run.

  • Cannot use trial id because trainer.restore will reuse trial id.
  • Cannot use job id because there could be multiple train runs in one job.
  • Cannot use trial id + job id because one can restore a run multiple times in one job.

),
storage=self._storage,
synchronous_result_reporting=True,