[AIR] Experiment restore stress tests #33706
Merged
gjoliver merged 32 commits into ray-project:master from justinvyu:air/experiment_restore_tests on Apr 13, 2023
Changes from all commits (32 commits, all by justinvyu):

a6799c0  fix typo
2e7927b  Add initial test
3c6c7a1  Merge branch 'master' of https://github.com/ray-project/ray into air/…
962612c  draft 2
db9af9c  working version of tuner restore stress test
64a915a  add case for trainer
057639f  Fix lint
000edba  minor fixes (wrap in main method)
34653f0  move to air
f9ad0a6  add to build
3fd9aff  change some configs
83b6025  Merge branch 'master' of https://github.com/ray-project/ray into air/…
bfa3050  [no_early_kickoff] merge
e149031  add helper file to the test srcs
6cf3f8b  Merge branch 'master' of https://github.com/ray-project/ray into air/…
309a5df  Fix test for trainer (don't serialize datasets)
0c4dd6e  Improve some docstrings
132e6e4  add csv datasource
cc38a5d  fix total runtime calculation to account for early end
83a174e  switch to using storage_path
6ec2764  Merge branch 'master' of https://github.com/ray-project/ray into air/…
1edfb52  fix for py37
7b01dcd  Merge branch 'master' of https://github.com/ray-project/ray into air/…
4aea844  Fix lint
238be73  Address some style comments
fae1aff  Some cleanup
66d53cc  Fix test
a398eea  Add some clarifying docstring
f7eda3e  more clarifications
b15a618  Merge branch 'master' of https://github.com/ray-project/ray into air/…
f236ef9  address comments
9de3ed9  Merge branch 'master' of https://github.com/ray-project/ray into air/…
New file added in this PR (183 lines):

```python
import collections
import json
import os
from pathlib import Path
import random
import time
from typing import Dict, List, Optional

import ray
from ray import air, tune
from ray.air import Checkpoint, session
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.tune.experiment import Trial


RUNNER_TYPE = os.environ.get("RUNNER_TYPE", "trainer")
STORAGE_PATH = os.environ.get("STORAGE_PATH", "/tmp/ray_results")
EXP_NAME = os.environ.get("EXP_NAME", "restore_integration_test")
CALLBACK_DUMP_FILE = os.environ.get(
    "CALLBACK_DUMP_FILE", "/tmp/callback_dump_file.json"
)
CSV_DATA_FILE = os.environ.get("CSV_DATA_FILE", "/tmp/dummy.csv")

TIME_PER_ITER_S = float(os.environ.get("TIME_PER_ITER_S", "0.5"))
NUM_TRIALS = int(os.environ.get("NUM_TRIALS", "1"))
MAX_CONCURRENT_TRIALS = int(os.environ.get("MAX_CONCURRENT_TRIALS", "2"))
ITERATIONS_PER_TRIAL = int(os.environ.get("ITERATIONS_PER_TRIAL", "64"))


class StatefulCallback(tune.Callback):
    def __init__(self):
        self._trial_iterations = collections.defaultdict(list)

    def on_trial_result(
        self,
        iteration: int,
        trials: List["Trial"],
        trial: "Trial",
        result: Dict,
        **info,
    ):
        self._trial_iterations[trial.trial_id].append(result["training_iteration"])

    def on_experiment_end(self, trials: List["Trial"], **info):
        # Save callback contents to file
        with open(CALLBACK_DUMP_FILE, "w") as f:
            json.dump(self.get_state(), f, indent=2)

    def get_state(self) -> Optional[Dict]:
        return {"trial_iters": self._trial_iterations.copy()}

    def set_state(self, state: Dict):
        self._trial_iterations = state["trial_iters"]


class StatefulSearcher(tune.search.Searcher):
    def __init__(
        self,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
    ):
        super().__init__(metric=metric, mode=mode)
        self._trial_count = 0

    def suggest(self, trial_id: str) -> Optional[Dict]:
        self._trial_count += 1
        return {"id": self._trial_count}

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ) -> None:
        pass

    def save(self, checkpoint_path: str):
        with open(checkpoint_path, "w") as f:
            json.dump({"trial_count": self._trial_count}, f)

    def restore(self, checkpoint_path: str):
        with open(checkpoint_path, "r") as f:
            state = json.load(f)
        self._trial_count = state["trial_count"]


def train_fn(config: dict, data: Optional[dict] = None):
    checkpoint = session.get_checkpoint()
    start = checkpoint.to_dict()["iteration"] + 1 if checkpoint else 1

    training_started_marker = Path(
        os.environ.get("RUN_STARTED_MARKER", "/tmp/does-not-exist")
    )
    if training_started_marker.exists():
        # Multiple workers may be trying to delete the same marker
        try:
            training_started_marker.unlink()
        except FileNotFoundError:
            pass

    for iteration in range(start, ITERATIONS_PER_TRIAL + 1):
        time.sleep(TIME_PER_ITER_S)

        session.report(
            {"score": random.random()},
            checkpoint=Checkpoint.from_dict({"iteration": iteration}),
        )


def tuner(experiment_path: str, run_config: air.RunConfig) -> tune.ResultGrid:
    trainable = tune.with_resources(train_fn, resources={"CPU": 1})
    trainable = tune.with_parameters(trainable, data={"dummy_data": [1, 2, 3]})

    if tune.Tuner.can_restore(experiment_path):
        tuner = tune.Tuner.restore(
            experiment_path, trainable=trainable, resume_errored=True
        )
    else:
        tuner = tune.Tuner(
            trainable,
            run_config=run_config,
            tune_config=tune.TuneConfig(
                num_samples=8,
                max_concurrent_trials=2,
                search_alg=StatefulSearcher(),
            ),
        )

    result_grid = tuner.fit()
    return result_grid


def trainer(experiment_path: str, run_config: air.RunConfig) -> air.Result:
    dataset_size = 128
    num_workers = 4

    def train_loop_per_worker(config):
        # Wrap the other train_fn with a check for the dataset.
        assert session.get_dataset_shard("train")
        train_fn(config)

    datasets = {
        "train": ray.data.range(dataset_size),
        "valid": ray.data.read_csv(CSV_DATA_FILE),
    }

    if DataParallelTrainer.can_restore(experiment_path):
        trainer = DataParallelTrainer.restore(
            experiment_path,
            datasets=datasets,
            train_loop_per_worker=train_loop_per_worker,
        )
    else:
        trainer = DataParallelTrainer(
            train_loop_per_worker,
            datasets=datasets,
            scaling_config=air.ScalingConfig(
                num_workers=num_workers, trainer_resources={"CPU": 0}
            ),
            run_config=run_config,
        )

    result = trainer.fit()
    return result


if __name__ == "__main__":
    experiment_path = os.path.join(STORAGE_PATH, EXP_NAME)

    ray.init()

    run_config = air.RunConfig(
        storage_path=STORAGE_PATH,
        name=EXP_NAME,
        checkpoint_config=air.CheckpointConfig(num_to_keep=1),
        callbacks=[StatefulCallback()],
    )

    if RUNNER_TYPE == "tuner":
        tuner(experiment_path, run_config)
    elif RUNNER_TYPE == "trainer":
        trainer(experiment_path, run_config)
    else:
        raise NotImplementedError(
            "`RUNNER_TYPE` environment var must be one of ['tuner', 'trainer']"
        )
```
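The `StatefulSearcher` above persists only a trial counter, which is exactly what lets a restored experiment continue suggesting fresh configurations instead of repeating earlier ones. Here is a dependency-free sketch of that save/restore contract (the `CounterState` class and its names are illustrative, not part of the Ray API):

```python
import json
import os
import tempfile


class CounterState:
    """Mirrors StatefulSearcher's checkpoint contract without Ray."""

    def __init__(self):
        self._trial_count = 0

    def suggest(self):
        # Each suggestion bumps the counter, like StatefulSearcher.suggest.
        self._trial_count += 1
        return {"id": self._trial_count}

    def save(self, path):
        # Serialize the minimal state needed to resume.
        with open(path, "w") as f:
            json.dump({"trial_count": self._trial_count}, f)

    def restore(self, path):
        # Reload the counter so suggestions continue where they left off.
        with open(path) as f:
            self._trial_count = json.load(f)["trial_count"]


searcher = CounterState()
searcher.suggest()
searcher.suggest()

ckpt = os.path.join(tempfile.mkdtemp(), "searcher.json")
searcher.save(ckpt)

# A fresh instance restored from the checkpoint picks up the count.
resumed = CounterState()
resumed.restore(ckpt)
assert resumed.suggest() == {"id": 3}  # continues at 3, no repeated ids
```

This is why the stress test can kill and restore the run repeatedly: as long as `save` captures everything `suggest` depends on, a restore is indistinguishable from an uninterrupted run.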
Review discussion on the marker-file deletion in `train_fn`:

Reviewer: instead of `try ... except`, can you just use `missing_ok=True`?
justinvyu: I used that originally, but it seems like `missing_ok` was introduced in py38.
Reviewer: can you please `except FileNotFoundError` instead then?
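The trade-off discussed above is real: `Path.unlink(missing_ok=True)` is the one-liner, but the keyword only exists on Python 3.8+, so a test that must also run on 3.7 needs the `try`/`except FileNotFoundError` form. A minimal sketch of the 3.7-compatible idiom (the `remove_marker` helper is illustrative):

```python
import tempfile
from pathlib import Path


def remove_marker(marker: Path) -> None:
    """Delete a marker file that several workers may race to remove.

    On Python 3.8+ this could be marker.unlink(missing_ok=True); the
    try/except form below behaves the same and also works on 3.7.
    """
    try:
        marker.unlink()
    except FileNotFoundError:
        pass  # another worker already deleted it


with tempfile.TemporaryDirectory() as tmp:
    marker = Path(tmp) / "run_started_marker"
    marker.touch()
    remove_marker(marker)  # deletes the file
    remove_marker(marker)  # second call is a no-op instead of raising
    assert not marker.exists()
```

Either way, the point is that deletion must be idempotent when multiple workers can reach the same marker concurrently.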