[Train] Add support for handling multiple batch data types for prepare_data_loader #26386
Conversation
Thanks @VishDev12! Do you think you can also add a test for this? There's already a test_torch_prepare_dataloader test in test_gpu.py. We can just use that same test but try with DataLoaders that also output dictionaries instead of just tuples.
Hey @amogkam, I've added a dict case in the test. Please let me know if something doesn't look quite right.
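Not the actual test in test_gpu.py, but a rough sketch of what such a dict case could look like inside a Ray Train worker function, using the LinearDatasetDict from the diff below (the batch size and exact assertions here are assumptions):

```python
import torch
import ray.train.torch


def train_fn():
    data_loader = torch.utils.data.DataLoader(
        LinearDatasetDict(a=2, b=5), batch_size=8
    )
    # prepare_data_loader should now move dict batches to the right
    # device, just like tuple batches.
    data_loader = ray.train.torch.prepare_data_loader(data_loader)
    device = ray.train.torch.get_device()
    for batch in data_loader:
        assert isinstance(batch, dict)
        assert batch["x"].device == device
        assert batch["y"].device == device
```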
Thanks @VishDev12, this LGTM! Just left one minor comment.
@@ -24,6 +24,11 @@ def __len__(self):
        return len(self.x)


class LinearDatasetDict(LinearDataset):
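For context, a self-contained sketch of what this new class might look like; the LinearDataset body below is an assumption modeled on Ray Train's linear regression example, and the dict keys are illustrative:

```python
import torch


class LinearDataset(torch.utils.data.Dataset):
    """Toy dataset for y = a * x + b, yielding (feature, label) tuples."""

    def __init__(self, a, b, size=1000):
        self.x = torch.arange(0, 10, 10 / size, dtype=torch.float32)
        self.y = a * self.x + b

    def __getitem__(self, index):
        return (self.x[index, None], self.y[index, None])

    def __len__(self):
        return len(self.x)


class LinearDatasetDict(LinearDataset):
    """Same data as LinearDataset, but __getitem__ returns a dict."""

    def __getitem__(self, index):
        return {"x": self.x[index, None], "y": self.y[index, None]}
```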
If this is only used in the test and not part of the example, can we move this to test_gpu directly?
That makes sense, I've moved it just under the fixtures defined in test_gpu, just to keep the separation between the test_* functions and the rest. But if you want it defined just above test_torch_prepare_dataloader, I could do that too.
Awesome, thanks @VishDev12!
[Train] Add support for handling multiple batch data types for prepare_data_loader (ray-project#26386): when working with Ray Train, using the ray.train.torch.prepare_data_loader method with a dataset that returns a dictionary instead of a tuple from its __getitem__ method causes issues. Co-authored-by: matthewdeng <[email protected]>
Revert "[Train] Add support for handling multiple batch data types for prepare_data_loader (ray-project#26386)" (ray-project#26483). This reverts commit 36229d1.
Revert "Revert "[Train] Add support for handling multiple batch data types for prepare_data_loader (ray-project#26386)" (ray-project#26483)" (ray-project#26491). This reverts commit e6c0403. Signed-off-by: Amog Kamsetty <[email protected]>
Why are these changes needed?

When working with Ray Train, using the ray.train.torch.prepare_data_loader method with a dataset that returns a dictionary instead of a tuple from its __getitem__ method causes issues.

Related issue number
Closes #26385
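For readers landing here from the issue: the change amounts to making the device transfer in the prepared DataLoader handle more container types than just tuples. A minimal sketch of that idea (illustrative only, not Ray's actual implementation; the helper name is made up):

```python
import torch


def move_to_device(batch, device):
    """Recursively move a batch to the target device.

    Covers the batch types discussed in this PR: bare tensors,
    tuples/lists of tensors, and dicts of tensors (for datasets whose
    __getitem__ returns a dict). Anything else is returned unchanged.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(item, device) for item in batch)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    return batch
```

A wrapper around the DataLoader can then apply this to every batch it yields, which is the behavior the dict case in the test exercises.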
Checks

- [x] I've run scripts/format.sh to lint the changes in this PR.