Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Add Segmentation Buffer for DT #27829

Merged
merged 7 commits into from
Aug 16, 2022

Conversation

charlesjsun
Copy link
Contributor

Signed-off-by: Charles Sun [email protected]

Why are these changes needed?

Added the SegmentationBuffer that DT (Decision Transformer) needs.

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@charlesjsun charlesjsun changed the title Added segmentation buffer [RLlib] Added Segmentation Buffer for DT Aug 12, 2022
# decision transformer
RETURNS_TO_GO = "returns_to_go"
ATTENTION_MASKS = "attention_masks"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arguments could be made for putting it here, or putting it elsewhere like in dt.py which doesn't exist yet.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're putting this here.

Signed-off-by: Charles Sun <[email protected]>
@charlesjsun charlesjsun changed the title [RLlib] Added Segmentation Buffer for DT [RLlib] Add Segmentation Buffer for DT Aug 15, 2022
Copy link
Member

@avnishn avnishn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor changes

# TODO: sample proportional to episode length
# Sample a random episode from the buffer and then sample a random
# segment from that episode.
buffer_ind = np.random.randint(0, len(self._buffer))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not a seedable call. Can you instantiate your class with a numpy rng and set the seed of the rng using the seed in the global config?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolved


obs = episode[SampleBatch.OBS][si:ei]
actions = episode[SampleBatch.ACTIONS][si:ei]
# Note that returns-to-go needs an extra elem as the target for the last action.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah nice

[returns_to_go, np.zeros((1, 1), dtype=returns_to_go.dtype)], axis=0
)

# Front-pad if at beginning of rollout.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add note that this is about inference :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolved

offset = min(self.max_seq_len, ep_len)
# We allow si to be negative (for now) because we want segments that only
# contains the first few transitions (and padd the rest),
# for example [0, 0, 0, 0, 0, 0, R0, s0, a0].
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolved

# add to buffer and check that only last one is kept (due to replacement)
buffer.add(batch)

assert len(_get_internal_buffer(buffer)) == 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs message

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolved

# decision transformer
RETURNS_TO_GO = "returns_to_go"
ATTENTION_MASKS = "attention_masks"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're putting this here.

Signed-off-by: Charles Sun <[email protected]>
Copy link
Contributor

@kouroshHakha kouroshHakha left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@charlesjsun Thanks for taking care of the todos :)

rllib/algorithms/dt/segmentation_buffer.py Show resolved Hide resolved
rllib/algorithms/dt/segmentation_buffer.py Outdated Show resolved Hide resolved
rllib/algorithms/dt/segmentation_buffer.py Outdated Show resolved Hide resolved
# TODO: sample proportional to episode length
# Sample a random episode from the buffer and then sample a random
# segment from that episode.
buffer_ind = np.random.randint(0, len(self._buffer))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

rllib/algorithms/dt/segmentation_buffer.py Show resolved Hide resolved
rllib/algorithms/dt/segmentation_buffer.py Outdated Show resolved Hide resolved
rllib/algorithms/dt/tests/test_segmentation_buffer.py Outdated Show resolved Hide resolved
rllib/algorithms/dt/tests/test_segmentation_buffer.py Outdated Show resolved Hide resolved
@kouroshHakha
Copy link
Contributor

@richardliaw This can be merged. It passes all the tests.

@richardliaw richardliaw merged commit 753fad9 into ray-project:master Aug 16, 2022
Stefan-1313 pushed a commit to Stefan-1313/ray_mod that referenced this pull request Aug 18, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants