-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Fix SampleBatch to_device() #27572
[RLlib] Fix SampleBatch to_device() #27572
Conversation
This reverts commit 74686a8.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a couple of minor questions.
looks cleeeaaaan!
# Floatify all float64 tensors. | ||
if tensor.dtype == torch.double: | ||
if tensor.is_floating_point(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any chance this is a new api? if yes, maybe make sure it works with the version we pin for core.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point. I checked quickly and it's even there on torch=1.8.0 Where do we specify the torch pinned version in ray?
https://pytorch.org/docs/1.8.0/search.html?q=IS_FLOATING_POINT&check_keywords=yes&area=default
# Numpy arrays. | ||
if isinstance(item, np.ndarray): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what were we doing before ... didn't we check these conditions before we get in here :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the else is added because I moved a bit of logic to clean the code. e.g torch.is_tensor() is brought down after dealing with RepeatedValues type.
"f": RepeatedValues(np.array([[1, 2, 0, 0]]), lengths=[2], max_len=4), | ||
SampleBatch.SEQ_LENS: np.array([2, 3, 1]), | ||
"state_in_0": np.array([1.0, 3.0, 4.0]), | ||
SampleBatch.INFOS: np.array([{"a": 1}, {"b": 2}, {"c": 3}]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
INFOS's dtype is object right? and we basically don't do anything to it if I understand correctly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. correct.
Signed-off-by: Huaiwei Sun <[email protected]>
Signed-off-by: Stefan van der Kleij <[email protected]>
Why are these changes needed?
Related issue number
closes #26593
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.