[train] wrap BackendExecutor in ray.remote() #20123
Conversation
        return next_results

    def _fetch_next_result(self) -> Optional[List[Dict]]:
With these changes, the "processing" of TrainingResults is happening in different places: reports are processed (yielded) in __next__, while checkpoints are processed in either _fetch_next_result or _finish_checkpointing.
Is it possible to refactor this a bit so that all of the processing happens in one place?
One possible implementation: we can have one method that just obtains the next TrainingResults from the actor, def _get_next_result(self) -> Optional[List[TrainingResult]], and then move most of the logic into __next__ (or use some other helper method):
def _get_next_result(self) -> Optional[List[TrainingResult]]:
    results = ray.get(self._executor.get_next_results.remote())
    return results

def _pause_reporting(self):
    ray.get(self._executor.pause_reporting.remote())

def _finish_training(self):
    # Assumes that all reporting and checkpointing are already finished.
    return ray.get(self._executor.finish_training.remote())

def __next__(self):
    while True:
        if self.is_finished():
            raise StopIteration
        next_results = self._run_with_error_handling(self._get_next_result)
        if next_results is None:
            # There are no more reports or checkpoints,
            # so we don't need to pause reporting here.
            try:
                self._final_results = self._run_with_error_handling(
                    self._finish_training)
            finally:
                self._finished_training = True
        else:
            first_result = next_results[0]
            result_type = first_result.type
            if result_type is TrainingResultType.REPORT:
                result_data = [r.data for r in next_results]
                return result_data
            elif result_type is TrainingResultType.CHECKPOINT:
                self._checkpoint_manager._process_checkpoint(next_results)
                # Iterate until the next REPORT call or until training
                # has finished.
            else:
                raise TrainBackendError(
                    f"Unexpected result type: {result_type}. "
                    f"Expected one of {list(TrainingResultType)}.")

def get_final_results(self, force: bool = False) -> List[T]:
    if not self.is_finished():
        assert self._final_results is None
        if force:
            # Pause reporting.
            self._run_with_error_handling(self._pause_reporting)
            # Iterate and process the remaining checkpoints. This also sets
            # self._final_results and self._finished_training.
            for _ in self:
                pass
            assert self.is_finished()
        else:
            logger.info("Please finish iterating through the "
                        "intermediate results before getting the "
                        "final returns. If you would like "
                        "training to finish immediately and get "
                        "the final returns, then set "
                        "`force=True`.")
    return self._final_results
And we no longer need _fetch_next_result or _finish_checkpointing.
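For illustration, the trainer-facing usage of such an iterator might look roughly like the following hypothetical sketch; `iterator` stands in for the training iterator object whose methods are sketched above, and nothing here is prescribed by the PR itself:

for intermediate_results in iterator:
    # Each item is the list of per-worker report dicts returned from
    # __next__ above.
    print(intermediate_results)

# After iteration completes, the final return values of the training
# function are available; passing force=True would instead pause
# reporting and finish training immediately.
final_results = iterator.get_final_results()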
Yeah, while I do agree we can do something like this, I don't actually understand the original concern:
"With these changes, the 'processing' of TrainingResults is happening in different places: reports are processed (yielded) in __next__, while checkpoints are processed in either _fetch_next_result or _finish_checkpointing."
This logic was just moved up from BackendExecutor.fetch_next_result and BackendExecutor.finish_training.
Do you think this refactoring should be done in this PR?
It’s just to simplify the logic and the abstractions. Previously the abstraction was that BackendExecutor would only return report results that the trainer will consume. But now that the trainer is also doing checkpointing, we don’t need an abstraction that only returns report results; it just feels like extra indirection to me.
What do you think about the code snippet above?
If it’s a pretty quick fix, then we can do it in this PR. If not, then we can do it in a follow-up.
I tried making this change and one of the tests (test_worker_failure_2) started hanging 😅 I might have missed one of the error handling wrappers or something.
Can I follow up with this in a separate PR? I'd like to spend some more time thinking about this in general as well.
Ok sounds good, let's do this refactor in a separate PR. Should we add a TODO or track this somehow?
Created a ticket to track this here: #20330
Thanks, sounds good!
LGTM! Just left some minor comments
@@ -1146,11 +1025,12 @@ def train_actor_failure():
    with patch.object(
            new_backend_executor_cls,
            "_get_dataset_shards",
            return_value=dataset_splits) as mock_method:
Thanks!
Why are these changes needed?
This PR allows Ray Train to be run with Ray Client. Wrapping the BackendExecutor allows the primary execution to occur on the cluster, while Trainer remains on the driver.

Changes
- Wrap BackendExecutor in ray.remote(). Use force_on_current_node to ensure this is scheduled on the head node in Ray Client mode.
- Change self._executor.XYZ() to ray.get(self._executor.XYZ.remote()). Wrapping in ray.get allows for synchronous execution (a minimal sketch of this pattern follows the list).
- Move checkpoint_manager from BackendExecutor to Trainer. For Ray Client mode, persisted checkpoints will be written to the disk of the driver.
- Pass TrainingResults up from BackendExecutor to Trainer. The Trainer will then process the report/checkpoint results within _fetch_next_result.
- Move finish_training from BackendExecutor to Trainer. This allows the Trainer to call the CheckpointManager while flushing the result queue, and removes this logic from BackendExecutor.
- Move some tests from test_trainer.py into a new test_examples.py, as test_trainer exceeds the 900 second timeout.
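A minimal sketch of the wrapping pattern described above, using a stand-in DemoExecutor class rather than Ray Train's actual BackendExecutor (whose constructor arguments and import path are omitted here); force_on_current_node is also left out, since its placement is specific to Ray Client setups:

import ray

class DemoExecutor:
    # Stand-in for BackendExecutor; the real class takes backend/worker
    # configuration arguments that this sketch omits.
    def start(self):
        return "started"

    def get_next_results(self):
        return [{"loss": 0.1}]

ray.init()

# Wrap the class so it runs as an actor on the cluster instead of in the
# driver process (in the PR, force_on_current_node additionally pins it to
# the head node for Ray Client mode).
RemoteExecutor = ray.remote(DemoExecutor)
executor = RemoteExecutor.remote()

# Every former `executor.method()` call becomes
# `ray.get(executor.method.remote())`, which blocks until the remote call
# completes, preserving synchronous semantics.
print(ray.get(executor.start.remote()))
print(ray.get(executor.get_next_results.remote()))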
Related issue number

Checks
I've run scripts/format.sh to lint the changes in this PR.