[AIR] Experiment restore stress tests #33706

Merged Apr 13, 2023 (32 commits)

Commits
a6799c0  fix typo (justinvyu, Mar 24, 2023)
2e7927b  Add initial test (justinvyu, Mar 24, 2023)
3c6c7a1  Merge branch 'master' of https://github.com/ray-project/ray into air/… (justinvyu, Mar 24, 2023)
962612c  draft 2 (justinvyu, Mar 24, 2023)
db9af9c  working version of tuner restore stress test (justinvyu, Mar 24, 2023)
64a915a  add case for trainer (justinvyu, Mar 24, 2023)
057639f  Fix lint (justinvyu, Mar 25, 2023)
000edba  minor fixes (wrap in main method) (justinvyu, Mar 25, 2023)
34653f0  move to air (justinvyu, Mar 25, 2023)
f9ad0a6  add to build (justinvyu, Mar 25, 2023)
3fd9aff  change some configs (justinvyu, Mar 25, 2023)
83b6025  Merge branch 'master' of https://github.com/ray-project/ray into air/… (justinvyu, Mar 25, 2023)
bfa3050  [no_early_kickoff] merge (justinvyu, Mar 25, 2023)
e149031  add helper file to the test srcs (justinvyu, Mar 25, 2023)
6cf3f8b  Merge branch 'master' of https://github.com/ray-project/ray into air/… (justinvyu, Apr 10, 2023)
309a5df  Fix test for trainer (don't serialize datasets) (justinvyu, Apr 10, 2023)
0c4dd6e  Improve some docstrings (justinvyu, Apr 10, 2023)
132e6e4  add csv datasource (justinvyu, Apr 10, 2023)
cc38a5d  fix total runtime calculation to account for early end (justinvyu, Apr 10, 2023)
83a174e  switch to using storage_path (justinvyu, Apr 10, 2023)
6ec2764  Merge branch 'master' of https://github.com/ray-project/ray into air/… (justinvyu, Apr 10, 2023)
1edfb52  fix for py37 (justinvyu, Apr 10, 2023)
7b01dcd  Merge branch 'master' of https://github.com/ray-project/ray into air/… (justinvyu, Apr 10, 2023)
4aea844  Fix lint (justinvyu, Apr 10, 2023)
238be73  Address some style comments (justinvyu, Apr 11, 2023)
fae1aff  Some cleanup (justinvyu, Apr 11, 2023)
66d53cc  Fix test (justinvyu, Apr 11, 2023)
a398eea  Add some clarifying docstring (justinvyu, Apr 11, 2023)
f7eda3e  more clarifications (justinvyu, Apr 11, 2023)
b15a618  Merge branch 'master' of https://github.com/ray-project/ray into air/… (justinvyu, Apr 11, 2023)
f236ef9  address comments (justinvyu, Apr 13, 2023)
9de3ed9  Merge branch 'master' of https://github.com/ray-project/ray into air/… (justinvyu, Apr 13, 2023)
python/ray/air/BUILD (11 additions, 0 deletions)

@@ -74,6 +74,17 @@ py_test(
    deps = [":ml_lib"]
)

py_test(
    name = "test_experiment_restore",
    size = "medium",
    srcs = [
        "tests/test_experiment_restore.py",
        "tests/_test_experiment_restore_run.py"
    ],
    tags = ["team:ml", "exclusive"],
    deps = [":ml_lib"]
)

py_test(
    name = "test_errors",
    size = "small",
python/ray/air/tests/_test_experiment_restore_run.py (183 additions, 0 deletions)

@@ -0,0 +1,183 @@
import collections
import json
import os
from pathlib import Path
import random
import time
from typing import Dict, List, Optional

import ray
from ray import air, tune
from ray.air import Checkpoint, session
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.tune.experiment import Trial


RUNNER_TYPE = os.environ.get("RUNNER_TYPE", "trainer")
STORAGE_PATH = os.environ.get("STORAGE_PATH", "/tmp/ray_results")
EXP_NAME = os.environ.get("EXP_NAME", "restore_integration_test")
CALLBACK_DUMP_FILE = os.environ.get(
    "CALLBACK_DUMP_FILE", "/tmp/callback_dump_file.json"
)
CSV_DATA_FILE = os.environ.get("CSV_DATA_FILE", "/tmp/dummy.csv")

TIME_PER_ITER_S = float(os.environ.get("TIME_PER_ITER_S", "0.5"))
NUM_TRIALS = int(os.environ.get("NUM_TRIALS", "1"))
MAX_CONCURRENT_TRIALS = int(os.environ.get("MAX_CONCURRENT_TRIALS", "2"))
ITERATIONS_PER_TRIAL = int(os.environ.get("ITERATIONS_PER_TRIAL", "64"))


class StatefulCallback(tune.Callback):
    def __init__(self):
        self._trial_iterations = collections.defaultdict(list)

    def on_trial_result(
        self,
        iteration: int,
        trials: List["Trial"],
        trial: "Trial",
        result: Dict,
        **info,
    ):
        self._trial_iterations[trial.trial_id].append(result["training_iteration"])

    def on_experiment_end(self, trials: List["Trial"], **info):
        # Save callback contents to file
        with open(CALLBACK_DUMP_FILE, "w") as f:
            json.dump(self.get_state(), f, indent=2)

    def get_state(self) -> Optional[Dict]:
        return {"trial_iters": self._trial_iterations.copy()}

    def set_state(self, state: Dict):
        self._trial_iterations = state["trial_iters"]


class StatefulSearcher(tune.search.Searcher):
    def __init__(
        self,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
    ):
        super().__init__(metric=metric, mode=mode)
        self._trial_count = 0

    def suggest(self, trial_id: str) -> Optional[Dict]:
        self._trial_count += 1
        return {"id": self._trial_count}

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ) -> None:
        pass

    def save(self, checkpoint_path: str):
        with open(checkpoint_path, "w") as f:
            json.dump({"trial_count": self._trial_count}, f)

    def restore(self, checkpoint_path: str):
        with open(checkpoint_path, "r") as f:
            state = json.load(f)
        self._trial_count = state["trial_count"]


def train_fn(config: dict, data: Optional[dict] = None):
    checkpoint = session.get_checkpoint()
    start = checkpoint.to_dict()["iteration"] + 1 if checkpoint else 1

    training_started_marker = Path(
        os.environ.get("RUN_STARTED_MARKER", "/tmp/does-not-exist")
    )
    if training_started_marker.exists():
        # Multiple workers may be trying to delete the same marker
        try:
            training_started_marker.unlink()
        except FileNotFoundError:
            pass

Review thread on the try/except above:
Member: instead of try ... except, can you just missing_ok=True?
Contributor (author): I used that originally but seems like missing_ok was introduced in py38.
Member: can you please except FileNotFoundError instead then?
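As background for the thread above: Path.unlink(missing_ok=True) is only available on Python 3.8+, which is why the Python 3.7-compatible try/except is used here. A minimal illustration of the two equivalent patterns, using a hypothetical marker path (not part of the diff):

    from pathlib import Path

    marker = Path("/tmp/example_marker")  # hypothetical path, for illustration only

    # Python 3.8+: ignore a missing file in a single call.
    # marker.unlink(missing_ok=True)

    # Python 3.7-compatible equivalent, as used in this PR:
    try:
        marker.unlink()
    except FileNotFoundError:
        pass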

    for iteration in range(start, ITERATIONS_PER_TRIAL + 1):
        time.sleep(TIME_PER_ITER_S)

        session.report(
            {"score": random.random()},
            checkpoint=Checkpoint.from_dict({"iteration": iteration}),
        )


def tuner(experiment_path: str, run_config: air.RunConfig) -> tune.ResultGrid:
    trainable = tune.with_resources(train_fn, resources={"CPU": 1})
    trainable = tune.with_parameters(trainable, data={"dummy_data": [1, 2, 3]})

    if tune.Tuner.can_restore(experiment_path):
        tuner = tune.Tuner.restore(
            experiment_path, trainable=trainable, resume_errored=True
        )
    else:
        tuner = tune.Tuner(
            trainable,
            run_config=run_config,
            tune_config=tune.TuneConfig(
                num_samples=8,
                max_concurrent_trials=2,
                search_alg=StatefulSearcher(),
            ),
        )

    result_grid = tuner.fit()
    return result_grid


def trainer(experiment_path: str, run_config: air.RunConfig) -> air.Result:
    dataset_size = 128
    num_workers = 4

    def train_loop_per_worker(config):
        # Wrap the other train_fn with a check for the dataset.
        assert session.get_dataset_shard("train")
        train_fn(config)

    datasets = {
        "train": ray.data.range(dataset_size),
        "valid": ray.data.read_csv(CSV_DATA_FILE),
    }

    if DataParallelTrainer.can_restore(experiment_path):
        trainer = DataParallelTrainer.restore(
            experiment_path,
            datasets=datasets,
            train_loop_per_worker=train_loop_per_worker,
        )
    else:
        trainer = DataParallelTrainer(
            train_loop_per_worker,
            datasets=datasets,
            scaling_config=air.ScalingConfig(
                num_workers=num_workers, trainer_resources={"CPU": 0}
            ),
            run_config=run_config,
        )

    result = trainer.fit()
    return result


if __name__ == "__main__":
    experiment_path = os.path.join(STORAGE_PATH, EXP_NAME)

    ray.init()

    run_config = air.RunConfig(
        storage_path=STORAGE_PATH,
        name=EXP_NAME,
        checkpoint_config=air.CheckpointConfig(num_to_keep=1),
        callbacks=[StatefulCallback()],
    )

    if RUNNER_TYPE == "tuner":
        tuner(experiment_path, run_config)
    elif RUNNER_TYPE == "trainer":
        trainer(experiment_path, run_config)
    else:
        raise NotImplementedError(
            "`RUNNER_TYPE` environment var must be one of ['tuner', 'trainer']"
        )
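The companion test_experiment_restore.py registered in the BUILD file above is not shown in this view. As a rough, hypothetical sketch (all names and control flow assumed, not taken from the PR), a driver could stress restoration by launching this script with the environment variables it reads, interrupting it once training has started, and re-running it against the same STORAGE_PATH/EXP_NAME; the RUN_STARTED_MARKER file that workers delete in train_fn would let the driver detect that training has begun:

    import os
    import signal
    import subprocess
    import time
    from pathlib import Path

    # Hypothetical driver sketch; the actual test_experiment_restore.py may differ.
    marker = Path("/tmp/run_started_marker")
    env = {
        **os.environ,
        "RUNNER_TYPE": "tuner",
        "STORAGE_PATH": "/tmp/ray_results",
        "EXP_NAME": "restore_integration_test",
        "RUN_STARTED_MARKER": str(marker),
    }

    # First run: create the marker, start the script, wait until a worker
    # deletes the marker (training has started), then interrupt the run.
    marker.touch()
    proc = subprocess.Popen(["python", "_test_experiment_restore_run.py"], env=env)
    while marker.exists():
        time.sleep(1)
    time.sleep(10)  # let some iterations and checkpoints accumulate
    proc.send_signal(signal.SIGINT)
    proc.wait()

    # Second run: Tuner.can_restore() / DataParallelTrainer.can_restore() now
    # returns True for the experiment path, so the script resumes the experiment
    # instead of starting over.
    subprocess.run(["python", "_test_experiment_restore_run.py"], env=env, check=True)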