[AIR/Train] Make Dataset ingest configurable #24066
Conversation
Co-authored-by: Eric Liang <[email protected]>
else:
    # Ray Train will strip out the added string before exposing to users.
    updated_dataset_dict[key + "_NO-SHARD"] = value

def dataset_split_fn(dataset_dict, training_worker_handles):
Could we pull this out into a default splitting function in the util module?
If we're pulling this out into a default splitting function, could you add a docstring? That would allow readers to understand the function without having to reference _RayDatasetSpec.
Separated it into its own function, but left it in data_parallel_trainer for now, as it is only being used in DataParallelTrainer. Let's revisit the location if we need it in more trainers in the future.
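For readers following along, here is a minimal sketch of what such a default splitting function could look like. It assumes `dataset.split(n, locality_hints=...)` behaves like `ray.data.Dataset.split`; the function name and structure are illustrative, not the actual Ray Train implementation:

```python
from typing import Any, Dict, List


def default_split_fn(
    dataset_dict: Dict[str, Any], training_worker_handles: List[Any]
) -> List[Dict[str, Any]]:
    """Shard every dataset in ``dataset_dict`` across all training workers.

    Returns one dict per worker, so worker ``i`` consumes ``result[i]``.
    """
    num_workers = len(training_worker_handles)
    # Split each dataset into num_workers pieces, colocating shards
    # with the worker actors via locality_hints.
    splits_per_key = {
        key: dataset.split(num_workers, locality_hints=training_worker_handles)
        for key, dataset in dataset_dict.items()
    }
    # Regroup: one {key: shard} dict per worker.
    return [
        {key: shards[i] for key, shards in splits_per_key.items()}
        for i in range(num_workers)
    ]
```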
python/ray/train/utils.py (Outdated)
)
if not len(splits) == len(training_worker_handles):
    raise RuntimeError(
        "The list of Datasets returned by the "
How about moving this class into a separate file, such as ml/train/impl/dataset_spec.py? (Also, this should go into ml/train for now, right? Given that ray/train is deprecated.)
+1 on moving dataset spec to its own module.
Moved to its own module! But kept it as part of ray/train. It's being used by the current Ray Train, and as discussed offline, the end state is to eventually move ray/ml/train to ray/train anyway.
@@ -320,12 +325,14 @@ def run(

    train_func = construct_train_func(train_func, config)

    dataset_spec = _RayDatasetSpec(dataset_or_dict=dataset)
At some point we should move / copy the trainer files into ml/train, right? In preparation for replacing the old train module.
I thought we decided the end state should be to move ray/ml/train to ray/train, right? But yes, agreed. We should definitely clean up the current ray/train in a future PR.
Definitely need the cleanup. I think the easiest way is to copy the files over and decouple them, but I'm open to other approaches.
Some readability comments, but overall looks good.
else:
    # Ray Train will strip out the added string before exposing to users.
    updated_dataset_dict[key + "_NO-SHARD"] = value

def dataset_split_fn(dataset_dict, training_worker_handles):
Could we add a type annotation to training_worker_handles? The type wasn't obvious until I read _RayDatasetSpec.
Added!
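For reference, the annotated signature could look roughly like this. `Dataset` and `ActorHandle` are stand-ins for `ray.data.Dataset` and `ray.actor.ActorHandle`, stubbed here so the sketch stays self-contained without Ray installed:

```python
from typing import Any, Dict, List

# Stand-ins for ray.data.Dataset and ray.actor.ActorHandle, so this
# sketch runs without Ray installed.
Dataset = Any
ActorHandle = Any


def dataset_split_fn(
    dataset_dict: Dict[str, Dataset],
    training_worker_handles: List[ActorHandle],
) -> List[Dict[str, Dataset]]:
    """Split the datasets across workers; see _RayDatasetSpec for the contract."""
    raise NotImplementedError
```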
        locality_hints=training_worker_handles,
    )
else:
    # Only shard the training dataset.
Could you add a comment explaining why we're only sharding the training dataset?
Added!
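As a hedged sketch of the behavior being discussed here, sharding only the training dataset while giving every worker a full copy of the other datasets could look like the following. The key name `TRAIN_DATASET_KEY = "train"` and the function name are assumptions for illustration:

```python
from typing import Any, Dict, List

TRAIN_DATASET_KEY = "train"  # assumed key name, for illustration only


def split_train_only(
    dataset_dict: Dict[str, Any], training_worker_handles: List[Any]
) -> List[Dict[str, Any]]:
    """Shard only the train dataset; other datasets are passed through whole."""
    num_workers = len(training_worker_handles)
    train_shards = dataset_dict[TRAIN_DATASET_KEY].split(
        num_workers, locality_hints=training_worker_handles
    )
    worker_dicts = []
    for i in range(num_workers):
        worker_datasets = {TRAIN_DATASET_KEY: train_shards[i]}
        for key, dataset in dataset_dict.items():
            if key != TRAIN_DATASET_KEY:
                # Not sharded: each worker sees the entire dataset.
                worker_datasets[key] = dataset
        worker_dicts.append(worker_datasets)
    return worker_dicts
```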
python/ray/train/utils.py (Outdated)

    ]
] = None

def _default_split_fn(
I'm confused why we need both _RayDataSpec._default_split_fn and dataset_split_fn in training_loop. Isn't dataset_split_fn the default for training?
dataset_split_fn is the implementation that DataParallelTrainer uses, but it is not the default for RayDatasetSpec in general.
More specifically, the default implementation for RayDatasetSpec is to split all datasets. DataParallelTrainer overrides this behavior to split just the train dataset and leave the other datasets unsplit. In the future, users should be able to override this behavior for DataParallelTrainer.
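The relationship described above can be sketched as a spec object whose built-in default splits every dataset, which a trainer can override by passing its own split function. The class and method names here are illustrative, not the actual Ray API:

```python
from typing import Any, Callable, Dict, List, Optional

SplitFn = Callable[[Dict[str, Any], List[Any]], List[Dict[str, Any]]]


class DatasetSpecSketch:
    """Holds datasets plus an overridable splitting strategy."""

    def __init__(
        self,
        dataset_dict: Dict[str, Any],
        dataset_split_fn: Optional[SplitFn] = None,
    ):
        self.dataset_dict = dataset_dict
        # Fall back to the split-everything default when no override is given.
        self.dataset_split_fn = dataset_split_fn or self._default_split_fn

    @staticmethod
    def _default_split_fn(
        dataset_dict: Dict[str, Any], handles: List[Any]
    ) -> List[Dict[str, Any]]:
        # Default behavior: shard every dataset across all workers.
        n = len(handles)
        per_key = {
            key: ds.split(n, locality_hints=handles)
            for key, ds in dataset_dict.items()
        }
        return [
            {key: shards[i] for key, shards in per_key.items()} for i in range(n)
        ]

    def get_dataset_shards(self, training_worker_handles: List[Any]):
        return self.dataset_split_fn(self.dataset_dict, training_worker_handles)
```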
@@ -348,3 +343,39 @@ def write_checkpoint(self, checkpoint: Dict):
    @property
    def latest_checkpoint_dir(self) -> Optional[Path]:
        raise NotImplementedError


def _default_dataset_split_fn(
Could we move this into the dataset spec file?
This function is specific to DataParallelTrainer, not to DatasetSpec in general.
LGTM
After #24066, some release tests are running into: ``` ModuleNotFoundError: No module named 'ray.train.impl' ``` This PR simply adds an `__init__.py` file to resolve this. We also add a 5 second delay for client runners in release tests to give clusters a bit of slack to come up (and avoid Ray Client connection errors).
Refactors Dataset splitting to make it less hacky and to address the TODO. Also makes Dataset ingest in general configurable for Ray Train. This is an internal-only change for now, but it sets the stage for the proposed ingest API.
Customizable ingest for GBDT Trainers is out of scope for this PR.
Why are these changes needed?
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.