[RLlib] Issue 21991: Fix SampleBatch slicing for SampleBatch.INFOS in RNN cases #22050

Status: Merged (3 commits, Apr 25, 2022)
2 changes: 1 addition & 1 deletion rllib/policy/rnn_sequencing.py
@@ -115,7 +115,7 @@ def pad_batch_to_sequences_of_same_size(
            elif (
                not feature_keys
                and not k.startswith("state_out_")
-               and k not in ["infos", SampleBatch.SEQ_LENS]
+               and k not in [SampleBatch.INFOS, SampleBatch.SEQ_LENS]
Contributor comment:

Nice!

            ):
                feature_keys_.append(k)
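The filtering logic in the hunk above can be sketched in isolation. This is a minimal standalone version (the constants and function name are hypothetical stand-ins, not the real `SampleBatch` API): keys that are neither recurrent state outputs nor `INFOS`/`SEQ_LENS` are collected as regular feature columns to be padded.

```python
# Hypothetical stand-ins for the SampleBatch string constants.
INFOS = "infos"
SEQ_LENS = "seq_lens"


def feature_keys_to_pad(batch_keys):
    """Collect keys that should be zero-padded as regular features.

    State outputs, infos, and seq_lens are excluded: infos holds
    arbitrary per-step dicts and is never padded.
    """
    feature_keys_ = []
    for k in batch_keys:
        if not k.startswith("state_out_") and k not in [INFOS, SEQ_LENS]:
            feature_keys_.append(k)
    return feature_keys_


print(feature_keys_to_pad(["obs", "actions", "infos", "state_out_0", "seq_lens"]))
# → ['obs', 'actions']
```

The point of the fix is only to compare against the `SampleBatch.INFOS` constant instead of a hardcoded `"infos"` string, so the exclusion stays in sync if the constant ever changes.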
22 changes: 13 additions & 9 deletions rllib/policy/sample_batch.py
@@ -936,26 +936,30 @@ def _slice(self, slice_: slice) -> "SampleBatch":
        # Build our slice-map, if not done already.
        if not self._slice_map:
            sum_ = 0
-           for i, l in enumerate(self[SampleBatch.SEQ_LENS]):
-               for _ in range(l):
-                   self._slice_map.append((i, sum_))
-               sum_ += l
Comment on lines -939 to -942 — Contributor Author @XuehaiPan (Feb 2, 2022):
Python's built-in type int is immutable, so it's okay to use x += y there: the statement rebinds x to a newly created int instance.

>>> x = 1000
>>> x
1000
>>> id(x)
140102686402576
>>> x += 1
>>> x
1001
>>> id(x)
140102686403152

Here, l and sum_ are tensors. sum_ += l mutates sum_ in place rather than creating a new instance, so slice_map.append((i, sum_)) keeps appending references to the very same tensor — every entry ends up holding the final cumulative value.

>>> import torch
>>> sum_ = 0
>>> l = torch.ones((), dtype=torch.int32)
>>> sum_
0
>>> id(sum_)
94230645960128
>>> sum_ += l
>>> sum_
tensor(1, dtype=torch.int32)
>>> id(sum_)
140100675095104
>>> sum_ += l
>>> sum_
tensor(2, dtype=torch.int32)
>>> id(sum_)
140100675095104
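The aliasing bug described above can be reproduced without torch. This is a minimal pure-Python sketch using a list as a stand-in for a tensor (lists, like tensors, implement `+=` in place), contrasted with the fixed pattern from the diff below:

```python
def build_slice_map_buggy(seq_lens):
    """Mimics the old loop with a mutable accumulator (tensor-like)."""
    slice_map = []
    sum_ = [0]  # mutable stand-in for a tensor accumulator
    for i, l in enumerate(seq_lens):
        for _ in range(l):
            slice_map.append((i, sum_))
        sum_ += [l]  # in-place: every stored reference sees this change
    return slice_map


def build_slice_map_fixed(seq_lens):
    """Mimics the fixed loop: plain ints, rebinding instead of mutating."""
    slice_map = []
    sum_ = 0
    for i, l in enumerate(map(int, seq_lens)):
        # `sum_ = sum_ + l` rebinds to a fresh object each iteration,
        # so earlier entries keep their own cumulative offset.
        slice_map.extend([(i, sum_)] * l)
        sum_ = sum_ + l
    return slice_map


buggy = build_slice_map_buggy([2, 3])
fixed = build_slice_map_fixed([2, 3])
print(buggy)  # all five entries alias the same final list [0, 2, 3]
print(fixed)  # [(0, 0), (0, 0), (1, 2), (1, 2), (1, 2)]
```

This is also why the fix wraps the sequence lengths in `map(int, ...)`: converting each length to a plain int up front guarantees the accumulator stays an immutable int for every iteration.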

+           for i, l in enumerate(map(int, self[SampleBatch.SEQ_LENS])):
+               self._slice_map.extend([(i, sum_)] * l)
+               sum_ = sum_ + l
            # In case `stop` points to the very end (lengths of this
            # batch), return the last sequence (the -1 here makes sure we
            # never go beyond it; would result in an index error below).
            self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_))

-       start_seq_len, start = self._slice_map[start]
-       stop_seq_len, stop = self._slice_map[stop]
+       start_seq_len, start_unpadded = self._slice_map[start]
+       stop_seq_len, stop_unpadded = self._slice_map[stop]
+       start_padded = start_unpadded
+       stop_padded = stop_unpadded
        if self.zero_padded:
-           start = start_seq_len * self.max_seq_len
-           stop = stop_seq_len * self.max_seq_len
+           start_padded = start_seq_len * self.max_seq_len
+           stop_padded = stop_seq_len * self.max_seq_len

        def map_(path, value):
            if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
                "state_in_"
            ):
-               return value[start:stop]
+               if path[0] != SampleBatch.INFOS:
+                   return value[start_padded:stop_padded]
+               else:
+                   return value[start_unpadded:stop_unpadded]
            else:
                return value[start_seq_len:stop_seq_len]
Comment on lines 955 to 964 — Contributor Author:

Special case for SampleBatch.INFOS (unpadded).
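The reason INFOS needs unpadded offsets can be shown with toy data (the names below are hypothetical, not the real `SampleBatch` class): when a batch is zero-padded, every regular column is stretched to `num_seqs * max_seq_len` rows, but `infos` — a Python list of per-step dicts — keeps its original, unpadded length. Slicing it with padded indices would therefore pick the wrong rows.

```python
import numpy as np

max_seq_len = 4
seq_lens = [2, 3]

# Regular columns are padded to num_seqs * max_seq_len = 8 rows.
obs_padded = np.arange(len(seq_lens) * max_seq_len)
# `infos` is never padded: only sum(seq_lens) = 5 real entries.
infos = [{"t": t} for t in range(sum(seq_lens))]

# Slice out the second sequence (sequence indices 1 .. 2).
start_seq, stop_seq = 1, 2

# Padded offsets: multiples of max_seq_len, as in the diff above.
obs_slice = obs_padded[start_seq * max_seq_len : stop_seq * max_seq_len]

# Unpadded offsets: cumulative seq_lens, as stored in the slice map.
start_unpadded = sum(seq_lens[:start_seq])  # 2
stop_unpadded = sum(seq_lens[:stop_seq])    # 5
infos_slice = infos[start_unpadded:stop_unpadded]

print(len(obs_slice))  # 4 padded rows for the sequence
print(infos_slice)     # the 3 real info dicts of sequence 1
```

Before this fix, `infos` was sliced with the padded offsets (4..8 here), which would return the wrong entries and run past the end of the list.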

Contributor comment:

Nice! Thanks for this fix.