[RLlib] Issue 21991: Fix SampleBatch slicing for SampleBatch.INFOS in RNN cases #22050
Conversation
```python
for i, l in enumerate(self[SampleBatch.SEQ_LENS]):
    for _ in range(l):
        self._slice_map.append((i, sum_))
    sum_ += l
```
Python's built-in type `int` is immutable, so it's okay to use `x += y`: the statement `x += y` will create a new `int` instance.
```python
>>> x = 1000
>>> x
1000
>>> id(x)
140102686402576
>>> x += 1
>>> x
1001
>>> id(x)
140102686403152
```
Here, `l` and `sum_` are tensors (after the first iteration), and `sum_ += l` will not create a new instance, so `slice_map.append((i, sum_))` keeps adding references to the same `sum_` instance in memory.
```python
>>> import torch
>>> sum_ = 0
>>> l = torch.ones((), dtype=torch.int32)
>>> sum_
0
>>> id(sum_)
94230645960128
>>> sum_ += l
>>> sum_
tensor(1, dtype=torch.int32)
>>> id(sum_)
140100675095104
>>> sum_ += l
>>> sum_
tensor(2, dtype=torch.int32)
>>> id(sum_)
140100675095104
```
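The practical consequence for `_slice_map` (a minimal sketch, assuming `SEQ_LENS` is an int32 tensor as in the demo above): every tuple appended after the first `+=` aliases the same tensor, so earlier entries silently change as `sum_` grows.

```python
import torch

seq_lens = torch.tensor([2, 3, 1], dtype=torch.int32)
slice_map = []
sum_ = 0
for i, l in enumerate(seq_lens):
    for _ in range(l):
        slice_map.append((i, sum_))
    sum_ += l  # in-place on a tensor after the first iteration

# The entries appended while sum_ was still the Python int 0 are fine,
# but all later entries alias a single tensor and show its final value:
print(slice_map)
# [(0, 0), (0, 0),
#  (1, tensor(6, dtype=torch.int32)), (1, tensor(6, dtype=torch.int32)),
#  (1, tensor(6, dtype=torch.int32)), (2, tensor(6, dtype=torch.int32))]
```

The intended offsets for the second and third sequences are 2 and 5, yet every tensor entry ends up as 6. One illustrative fix (an assumption, not necessarily the exact change in this PR) is to keep `sum_` a Python `int`, e.g. `sum_ += int(l)`.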
```diff
 def map_(path, value):
     if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
         "state_in_"
     ):
-        return value[start:stop]
+        if path[0] != SampleBatch.INFOS:
+            return value[start_padded:stop_padded]
+        else:
+            return value[start_unpadded:stop_unpadded]
     else:
         return value[start_seq_len:stop_seq_len]
```
Special case for `SampleBatch.INFOS` (unpadded).
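To illustrate why separate unpadded indices are needed (a hypothetical example; the sizes and values are assumptions, not taken from the PR): padded columns grow to `len(SEQ_LENS) * max_seq_len` rows, while `INFOS` keeps its original length, so the same logical slice maps to different index pairs.

```python
import numpy as np

# Assumed: max_seq_len = 3, SEQ_LENS = [2, 3].
# Padded columns have 2 * 3 = 6 rows; INFOS keeps 2 + 3 = 5 entries.
obs = np.arange(6)                    # padded: [o0, o1, PAD, o2, o3, o4]
infos = [{"t": t} for t in range(5)]  # unpadded: 5 dicts, no padding

# Slicing out the second sequence therefore needs two index pairs:
start_padded, stop_padded = 3, 6      # rows 3..6 of padded columns
start_unpadded, stop_unpadded = 2, 5  # entries 2..5 of INFOS

print(obs[start_padded:stop_padded])        # [3 4 5] -> o2, o3, o4
print(infos[start_unpadded:stop_unpadded])  # the matching three info dicts
```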
Nice! Thanks for this fix.
```diff
@@ -730,6 +730,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> Dict[str, TensorType]
                 if data_col in [
                     SampleBatch.OBS,
+                    SampleBatch.INFOS,
```
In method `get_inference_input_dict()`, the key `INFOS` should be shifted together with `OBS` / `T`.
```python
# initial step
obs = env.reset()
info = {}
t = -1
state = policy.model.get_initial_state()
results = policy.compute_single_action(obs,         # obs at `t = -1`
                                       info=info,   # info at `t = -1`
                                       state=state)
action, state, *_ = results

# interaction (t = 0)
obs, reward, done, info = env.step(action)
t = 0
results = policy.compute_single_action(obs,         # obs at `t = 0`
                                       info=info,   # info at `t = 0`
                                       state=state)
action, state, *_ = results

# interaction (t = 1)
obs, reward, done, info = env.step(action)
t = 1
...
```
In the buffer, for an environment transition step tuple `(obs, action, reward, next_obs, done, info)`, we have:
```python
# time step:
#   t     t        t+1     t+1    t+1   t+1
(obs, action, next_obs, reward, done, info)
```
The values `(next_obs, reward, done, info)` are the results of the next time step `res = env.step(action)`.
This change will cause inconsistent time-step alignment for `OBS` and `INFOS` between inference and training:
```python
# Inference:
#   feed (obs, infos) to policy.compute_action()
t_OBS = t_INFOS

# Train:
#   tuple (obs, action, next_obs, reward, done, info)
t_OBS = t_INFOS - 1
```
In the policy model, the user should have:
```python
class CustomModel:
    def __init__(self, ...):
        # ...
        self.view_requirements[SampleBatch.INFOS] = ViewRequirement()

    def forward(self, input_dict, state, seq_lens):
        if self.training:
            ...  # -> t_OBS + 1 = t_NEXT_OBS = t_INFOS
        else:  # sampling, postprocess_trajectory
            ...  # -> t_OBS = t_NEXT_OBS - 1 = t_INFOS
```
Alternative approach: we could add new keys `CUR_INFOS` and `NEXT_INFOS` to `SampleBatch` to distinguish the two alignments in sampling and training.
The key `INFOS` would always stand for the info aligned with the next observation (`NEXT_OBS`) returned from `next_obs, r, d, info = env.step()`. In the buffer, we would have the tuple:
```python
(obs, action, next_obs, reward, done, info) -> (obs, action, next_obs, reward, done, next_info)

# In sampler
t_OBS = t_CUR_INFOS = t_INFOS - 1
t_NEXT_OBS = t_NEXT_INFOS = t_INFOS  # invalid

# In trainer
t_OBS = t_CUR_INFOS = t_INFOS - 1
t_NEXT_OBS = t_NEXT_INFOS = t_INFOS
```
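A minimal sketch of how such keys could be wired up through view requirements (the key names `cur_infos` / `next_infos` and the shifts are assumptions for illustration, not existing `SampleBatch` keys):

```python
from ray.rllib.policy.view_requirement import ViewRequirement

# Hypothetical view requirements deriving both keys from the stored
# "infos" column via explicit time shifts:
view_requirements = {
    "cur_infos": ViewRequirement(data_col="infos", shift=-1),
    "next_infos": ViewRequirement(data_col="infos", shift=0),
}
```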
Hi, @sven1977, could you have a look at this PR? Thanks!
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.
Force-pushed from 7c1caff to 021ac08.
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.
Signed-off-by: Xuehai Pan <[email protected]>
…ence_input_dict()`" This reverts commit 47149e354ff9024809763e7e585732dbc67c231e. Signed-off-by: Xuehai Pan <[email protected]>
Force-pushed from 021ac08 to cec66c5.
```diff
@@ -115,7 +115,7 @@ def pad_batch_to_sequences_of_same_size(
         elif (
             not feature_keys
             and not k.startswith("state_out_")
-            and k not in ["infos", SampleBatch.SEQ_LENS]
+            and k not in [SampleBatch.INFOS, SampleBatch.SEQ_LENS]
```
Nice!
Thanks for this PR @XuehaiPan and the detailed explanations! :)
Why are these changes needed?
In RNN cases, the sample batch will be padded and stacked into `torch.Tensor`s before being fed into the methods `model.forward` and `model.custom_loss`, except for `INFOS` and `SEQ_LENS`:

ray/rllib/policy/rnn_sequencing.py, Lines 109 to 120 in b73a007
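To make the padding behavior concrete (an illustrative sketch with assumed `SEQ_LENS = [2, 3]`, not the actual RLlib implementation): numeric columns can be zero-padded to `max_seq_len` per sequence, but `INFOS` holds arbitrary dicts with no meaningful zero value, so it is left unpadded.

```python
import numpy as np

seq_lens = [2, 3]
max_seq_len = max(seq_lens)

obs = [np.float32(t) for t in range(5)]   # 5 unpadded steps
infos = [{"t": t} for t in range(5)]      # 5 arbitrary dicts

padded_obs, offset = [], 0
for l in seq_lens:
    padded_obs.extend(obs[offset:offset + l])
    padded_obs.extend([np.float32(0)] * (max_seq_len - l))  # zero-pad
    offset += l

print(padded_obs)  # [0.0, 1.0, 0.0, 2.0, 3.0, 4.0] -> 2 * max_seq_len rows
print(infos)       # left untouched: dicts have no meaningful zero value
```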
The keys `SampleBatch.INFOS`, `SampleBatch.SEQ_LENS`, and `"state_*"` will not be padded. Therefore the key `SampleBatch.INFOS` needs to be specially treated as well during batch slicing (`batch[start:end]`):

ray/rllib/policy/sample_batch.py, Lines 924 to 930 in 9c95b9a
In `SampleBatch._slice_map` (ray/rllib/policy/sample_batch.py, Lines 906 to 916 in 9c95b9a): the value for the key `SampleBatch.SEQ_LENS` is a tensor, so the variable `l` is an integer tensor instead of a Python `int`. The instance `sum_` can then be a tensor that gets reused, because the in-place operator `+=` is used. See [RLlib] Issue 21991: Fix `SampleBatch` slicing for `SampleBatch.INFOS` in RNN cases #22050 (comment).

Related issue number
Fixes #21991
Checks
- I've run `scripts/format.sh` to lint the changes in this PR.