
[AIR/Train] Make Dataset ingest configurable #24066

Merged: 18 commits merged into ray-project:master on Apr 28, 2022

Conversation

@amogkam (Contributor) commented Apr 21, 2022

Refactors Dataset splitting to make it less hacky and addresses the TODO. It also makes Dataset ingest for Ray Train configurable in general. This is an internal-only change for now, but it sets the stage for the proposed ingest API.

Customizable ingest for GBDT Trainers is out of scope for this PR.

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

```
else:
    # Ray Train will strip out the added string before exposing to users.
    updated_dataset_dict[key + "_NO-SHARD"] = value

def dataset_split_fn(dataset_dict, training_worker_handles):
```

Contributor:

Could we pull this out into a default splitting function in the util module?

Member:

If we're pulling this out into a default splitting function, could you add a docstring? Would allow readers to understand the function without having to reference _RayDatasetSpec.

Contributor (Author):

Separated into its own function, but left it in data_parallel_trainer for now as it is only being used in DataParallelTrainer. Let's revisit the location if we need it in more trainers in the future.
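The `_NO-SHARD` suffix trick visible in the context snippet above can be sketched in isolation. This is an illustrative stand-in, not the actual Ray Train implementation; the function name `mark_unsharded_datasets` and the `train_key` parameter are hypothetical.

```python
# Illustrative sketch (not the actual Ray Train code): tag every
# non-training dataset with a "_NO-SHARD" suffix so downstream code
# knows to hand it to each worker whole rather than splitting it.
# Per the comment in the diff, Ray Train strips the suffix again
# before exposing the dataset to users.
def mark_unsharded_datasets(dataset_dict, train_key="train"):
    updated_dataset_dict = {}
    for key, value in dataset_dict.items():
        if key == train_key:
            # The training dataset will be split across workers.
            updated_dataset_dict[key] = value
        else:
            # Ray Train will strip out the added string before exposing to users.
            updated_dataset_dict[key + "_NO-SHARD"] = value
    return updated_dataset_dict
```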

```
)
if not len(splits) == len(training_worker_handles):
    raise RuntimeError(
        "The list of Datasets returned by the "
```

Contributor:

How about moving this class into a separate file, such as ml/train/impl/dataset_spec.py?

(Also, this should go into ml/train for now, right? Given that ray/train is deprecated.)

Member:

+1 on moving dataset spec to its own module.

Contributor (Author):

Moved to its own module!

But kept it as part of ray/train. It's being used by the current Ray Train, and as discussed offline, the end state is to eventually move ray/ml/train to ray/train anyway.
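The RuntimeError check shown in the diff context earlier in this thread can be sketched as a standalone function. The completed error message is an assumption; the diff truncates the original string after its first line, and the function name is illustrative.

```python
# Sketch of the shard-count validation shown in the diff context above.
# A user-supplied split function must return exactly one Dataset shard
# per training worker; anything else is an error. The wording of the
# message past its first line is an assumption.
def validate_splits(splits, training_worker_handles):
    if not len(splits) == len(training_worker_handles):
        raise RuntimeError(
            "The list of Datasets returned by the dataset split function "
            f"has length {len(splits)}, but it must match the number of "
            f"training workers ({len(training_worker_handles)})."
        )
    return splits
```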

```
@@ -320,12 +325,14 @@ def run(

    train_func = construct_train_func(train_func, config)

    dataset_spec = _RayDatasetSpec(dataset_or_dict=dataset)
```

Contributor:

At some point we should move / copy the trainer files into ml/train right? In preparation for replacing the old train module.

Contributor (Author):

I thought we decided the end state should be to move ray/ml/train to ray/train, right?

But yes, agreed: we should definitely clean up the current ray/train in a future PR.

Contributor:

Definitely need the cleanup. I think the easiest way is to copy files over and decouple them, but open to other approaches.

@ericl ericl added the @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. label Apr 21, 2022
@jovany-wang jovany-wang self-assigned this Apr 21, 2022
@bveeramani (Member) left a comment:

Some readability comments, but overall looks good.

```
else:
    # Ray Train will strip out the added string before exposing to users.
    updated_dataset_dict[key + "_NO-SHARD"] = value

def dataset_split_fn(dataset_dict, training_worker_handles):
```

Member:

Could we add a type annotation to training_worker_handles? The type wasn't obvious until I read _RayDatasetSpec.

Contributor (Author):

Added!

```
    locality_hints=training_worker_handles,
)
else:
    # Only shard the training dataset.
```

Member:

Could you add a comment explaining why we're only sharding the training dataset?

Contributor (Author):

Added!

```
    ]
] = None

def _default_split_fn(
```

Member:

I'm confused why we need both _RayDatasetSpec._default_split_fn and dataset_split_fn in training_loop. Isn't dataset_split_fn the default for training?

Contributor (Author):

dataset_split_fn is the implementation that DataParallelTrainer uses, but is not the default for RayDatasetSpec in general.

Contributor (Author):

More specifically, the default implementation for RayDatasetSpec is to split all datasets.

DataParallelTrainer is overriding this behavior to split just the train dataset, but not split the other datasets.

In the future, users should be able to override the behavior for DataParallelTrainer.
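The two behaviors contrasted above can be sketched with plain Python stand-ins. This is purely illustrative: lists stand in for ray.data.Dataset objects, round-robin slicing stands in for Dataset.split(n, locality_hints=...), and both function names are hypothetical.

```python
# Illustrative sketch of the two splitting behaviors discussed in this
# thread. Plain lists stand in for ray.data.Dataset objects, and
# round-robin slicing stands in for Dataset.split(n, locality_hints=...).

def split_all(dataset_dict, n_workers):
    # Default _RayDatasetSpec behavior: split every dataset across workers.
    return [
        {key: data[i::n_workers] for key, data in dataset_dict.items()}
        for i in range(n_workers)
    ]

def split_train_only(dataset_dict, n_workers):
    # DataParallelTrainer behavior: shard only the "train" dataset;
    # each worker receives every other dataset whole.
    return [
        {
            key: (data[i::n_workers] if key == "train" else data)
            for key, data in dataset_dict.items()
        }
        for i in range(n_workers)
    ]
```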

@amogkam amogkam removed the @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. label Apr 25, 2022
```
@@ -348,3 +343,39 @@ def write_checkpoint(self, checkpoint: Dict):
    @property
    def latest_checkpoint_dir(self) -> Optional[Path]:
        raise NotImplementedError


def _default_dataset_split_fn(
```

Contributor:

Could we move this into the dataset spec file?

Contributor (Author):

This function is specific to DataParallelTrainer, not to DatasetSpec in general.

@ericl ericl added the @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. label Apr 25, 2022
@amogkam amogkam requested a review from ericl April 25, 2022 23:30
@amogkam amogkam removed the @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. label Apr 25, 2022
@ericl ericl added the @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. label Apr 26, 2022
@bveeramani (Member) left a comment:

LGTM

@amogkam amogkam added tests-ok The tagger certifies test failures are unrelated and assumes personal liability. and removed @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. labels Apr 28, 2022
@amogkam amogkam merged commit 629424f into ray-project:master Apr 28, 2022
@amogkam amogkam deleted the air-dataset-split-refactor branch April 28, 2022 04:41
krfricke added a commit that referenced this pull request Apr 29, 2022
After #24066, some release tests are running into:

```
ModuleNotFoundError: No module named 'ray.train.impl'
```

This PR simply adds an `__init__.py` file to resolve this.

We also add a 5-second delay for client runners in release tests to give clusters a bit of slack to come up (and avoid Ray Client connection errors).
6 participants