[RLlib] Add Segmentation Buffer for DT #27829
Conversation
# decision transformer
RETURNS_TO_GO = "returns_to_go"
ATTENTION_MASKS = "attention_masks"
Arguments could be made for putting it here, or putting it elsewhere, like in dt.py, which doesn't exist yet.
we're putting this here.
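For context, here is a sketch of a batch keyed by the new constants; the placeholder data and shapes are hypothetical, and note the extra return-to-go entry that serves as the target for the last action (discussed further down in this review):

import numpy as np

RETURNS_TO_GO = "returns_to_go"
ATTENTION_MASKS = "attention_masks"

# Hypothetical placeholder batch (T = 10 timesteps); shapes are illustrative.
batch = {
    "obs": np.zeros((10, 4), dtype=np.float32),
    "actions": np.zeros((10, 2), dtype=np.float32),
    RETURNS_TO_GO: np.zeros((11, 1), dtype=np.float32),  # one extra target entry
    ATTENTION_MASKS: np.ones(10, dtype=np.float32),
}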
Signed-off-by: Charles Sun <[email protected]>
Force-pushed from 9f8ce0b to d6d446c
minor changes
# TODO: sample proportional to episode length
# Sample a random episode from the buffer and then sample a random
# segment from that episode.
buffer_ind = np.random.randint(0, len(self._buffer))
this is not a seedable call. Can you instantiate your class with a numpy rng and set the seed of the rng using the seed in the global config?
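A minimal sketch of what this asks for, assuming the seed is passed down from the global config (the method and attribute names here are illustrative, not the PR's code):

import numpy as np

class SegmentationBuffer:
    def __init__(self, capacity: int, seed: int = None):
        self._buffer = []
        self.capacity = capacity
        # Own a seedable RNG instead of relying on numpy's global state.
        self._rng = np.random.default_rng(seed)

    def _sample_episode_index(self) -> int:
        # Seedable replacement for np.random.randint(0, len(self._buffer)).
        return int(self._rng.integers(0, len(self._buffer)))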
+1
resolved
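The TODO above could be addressed along these lines (a sketch, not the PR's code):

import numpy as np

def sample_episode_index(episode_lengths, rng):
    # Weight each episode by its length so every transition is equally
    # likely to end up in a sampled segment.
    lengths = np.asarray(episode_lengths, dtype=np.float64)
    return int(rng.choice(len(lengths), p=lengths / lengths.sum()))

rng = np.random.default_rng(42)
idx = sample_episode_index([10, 100, 5], rng)  # the length-100 episode is picked most often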
obs = episode[SampleBatch.OBS][si:ei]
actions = episode[SampleBatch.ACTIONS][si:ei]
# Note that returns-to-go needs an extra elem as the target for the last action.
ah nice
returns_to_go = np.concatenate(
    [returns_to_go, np.zeros((1, 1), dtype=returns_to_go.dtype)], axis=0
)
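For reference, returns-to-go are reverse cumulative rewards; a sketch of the undiscounted version used in the Decision Transformer paper, with the extra zero target appended as in the snippet above:

import numpy as np

def returns_to_go(rewards: np.ndarray) -> np.ndarray:
    # rtg[t] = sum of rewards from step t to the end of the episode.
    return np.cumsum(rewards[::-1])[::-1]

rtg = returns_to_go(np.array([1.0, 0.0, 2.0]))   # -> [3., 2., 2.]
rtg = np.concatenate([rtg, np.zeros(1)])         # extra target for the last action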
# Front-pad if at beginning of rollout. |
add note that this is about inference :)
resolved
offset = min(self.max_seq_len, ep_len)
# We allow si to be negative (for now) because we want segments that only
# contain the first few transitions (and pad the rest),
# for example [0, 0, 0, 0, 0, 0, R0, s0, a0].
typo
resolved
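In isolation, the front-padding that the comment above describes might look like this (the helper name and mask convention are assumptions):

import numpy as np

def front_pad(segment: np.ndarray, max_seq_len: int):
    # Left-pad a (T, ...) segment with zeros up to max_seq_len and build the
    # matching attention mask (0.0 marks padding, 1.0 marks real steps).
    pad_len = max_seq_len - segment.shape[0]
    padding = np.zeros((pad_len,) + segment.shape[1:], dtype=segment.dtype)
    mask = np.concatenate(
        [np.zeros(pad_len, dtype=np.float32), np.ones(segment.shape[0], dtype=np.float32)]
    )
    return np.concatenate([padding, segment], axis=0), mask

padded, mask = front_pad(np.arange(1.0, 4.0), 9)
# padded -> [0. 0. 0. 0. 0. 0. 1. 2. 3.], mask -> [0. 0. 0. 0. 0. 0. 1. 1. 1.]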
# add to buffer and check that only last one is kept (due to replacement)
buffer.add(batch)

assert len(_get_internal_buffer(buffer)) == 1
needs message
resolved
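One way the assertion above could carry a message (the exact wording is illustrative):

assert len(_get_internal_buffer(buffer)) == 1, (
    "Buffer should have kept only the most recently added episode "
    "after replacement."
)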
Signed-off-by: Charles Sun <[email protected]>
@charlesjsun Thanks for taking care of the todos :)
@richardliaw This can be merged. It passes all the tests.
Signed-off-by: Stefan van der Kleij <[email protected]>
Signed-off-by: Charles Sun <[email protected]>
Why are these changes needed?
Added the SegmentationBuffer that DT (Decision Transformer) needs.
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.