
[RLlib] Issue 21991: Fix SampleBatch slicing for SampleBatch.INFOS in RNN cases #22050

Merged
merged 3 commits into ray-project:master from samplebatch-slicing-infos on Apr 25, 2022

Conversation

Contributor

@XuehaiPan XuehaiPan commented Feb 2, 2022

Why are these changes needed?

  1. In RNN cases, the sample batch is padded and stacked into torch.Tensor before being fed into model.forward and model.custom_loss, except for the INFOS and SEQ_LENS columns:

    # RNN, attention net, or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k, v in batch.items():
        if k.startswith("state_in_"):
            state_keys.append(k)
        elif (
            not feature_keys
            and not k.startswith("state_out_")
            and k not in ["infos", SampleBatch.SEQ_LENS]
        ):
            feature_keys_.append(k)

    The keys SampleBatch.INFOS, SampleBatch.SEQ_LENS, and "state_*" will not be padded.

    Therefore, the key SampleBatch.INFOS also needs special treatment during batch slicing (batch[start:end]).

    def map_(path, value):
        if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
            "state_in_"
        ):
            return value[start:stop]
        else:
            return value[start_seq_len:stop_seq_len]

  2. In SampleBatch._slice_map:

    # Build our slice-map, if not done already.
    if not self._slice_map:
        sum_ = 0
        for i, l in enumerate(self[SampleBatch.SEQ_LENS]):
            for _ in range(l):
                self._slice_map.append((i, sum_))
            sum_ += l
        # In case `stop` points to the very end (lengths of this
        # batch), return the last sequence (the -1 here makes sure we
        # never go beyond it; would result in an index error below).
        self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_))

    The value for key SampleBatch.SEQ_LENS is a tensor, so the loop variable l is a 0-d integer tensor rather than a Python int. Because the in-place operator += is used, sum_ becomes a tensor as well, and the same (repeatedly mutated) tensor instance ends up being reused in every appended entry.

    (Screenshots of the debugger and the variable watch omitted.)

  3. See [RLlib] Issue 21991: Fix SampleBatch slicing for SampleBatch.INFOS in RNN cases #22050 (comment)

Related issue number

Fixes #21991

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Comment on lines -909 to -942
for i, l in enumerate(self[SampleBatch.SEQ_LENS]):
    for _ in range(l):
        self._slice_map.append((i, sum_))
    sum_ += l
Contributor Author

@XuehaiPan XuehaiPan Feb 2, 2022


Python's built-in type int is immutable, so it's okay to use x += y: the statement rebinds x to a new int instance.

>>> x = 1000

>>> x
1000
>>> id(x)
140102686402576

>>> x += 1
>>> x
1001
>>> id(x)
140102686403152

Here, however, l and sum_ are tensors. sum_ += l mutates sum_ in place instead of creating a new instance, so self._slice_map.append((i, sum_)) keeps appending references to the same tensor object in memory.

>>> import torch
>>> sum_ = 0
>>> l = torch.ones((), dtype=torch.int32)

>>> sum_
0
>>> id(sum_)
94230645960128

>>> sum_ += l
>>> sum_
tensor(1, dtype=torch.int32)
>>> id(sum_)
140100675095104

>>> sum_ += l
>>> sum_
tensor(2, dtype=torch.int32)
>>> id(sum_)
140100675095104
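
One way to avoid the aliasing (a sketch of the idea, not necessarily the exact code in this PR) is to accumulate with a plain Python int, so every appended offset is an immutable value:

import torch

seq_lens = torch.tensor([2, 3], dtype=torch.int32)  # stand-in for self[SampleBatch.SEQ_LENS]
slice_map = []
sum_ = 0
for i, l in enumerate(seq_lens):
    l = int(l)  # turn the 0-d tensor into an immutable Python int
    for _ in range(l):
        slice_map.append((i, sum_))
    sum_ += l  # rebinds sum_ to a new int; previously appended offsets stay correct
slice_map.append((len(seq_lens), sum_))

print(slice_map)  # [(0, 0), (0, 0), (1, 2), (1, 2), (1, 2), (2, 5)]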

Comment on lines 925 to 964
def map_(path, value):
    if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
        "state_in_"
    ):
-       return value[start:stop]
+       if path[0] != SampleBatch.INFOS:
+           return value[start_padded:stop_padded]
+       else:
+           return value[start_unpadded:stop_unpadded]
    else:
        return value[start_seq_len:stop_seq_len]
Contributor Author


Special case for SampleBatch.INFOS (unpadded).
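
To see why the unpadded indices matter, here is a small standalone sketch (plain NumPy with made-up seq_lens and max_seq_len, not RLlib code): padded tensor columns grow to num_seqs * max_seq_len rows, while the infos list keeps exactly one entry per real timestep, so slicing it with padded indices picks the wrong entries.

import numpy as np

max_seq_len = 4
seq_lens = [2, 3]                     # two sequences of real lengths 2 and 3
obs = np.arange(5)                    # 5 real (unpadded) timesteps: 0..4
infos = [{"t": t} for t in range(5)]  # one info dict per real timestep, never padded

# Zero-pad `obs` to 2 * max_seq_len = 8 rows, like the padded columns in an RNN batch.
padded_obs = np.zeros(len(seq_lens) * max_seq_len, dtype=obs.dtype)
offset = 0
for i, l in enumerate(seq_lens):
    padded_obs[i * max_seq_len : i * max_seq_len + l] = obs[offset : offset + l]
    offset += l

# Slicing the second sequence with *padded* indices is correct for padded columns ...
print(padded_obs[4:8])  # -> [2 3 4 0]
# ... but wrong for the unpadded `infos` list, which needs the unpadded indices [2:5]:
print(infos[4:8])       # -> [{'t': 4}] (one entry instead of three)
print(infos[2:5])       # -> [{'t': 2}, {'t': 3}, {'t': 4}]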

Contributor


Nice! Thanks for this fix.

@@ -730,6 +730,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> Dict[str, TensorType]
if data_col
in [
SampleBatch.OBS,
SampleBatch.INFOS,
Contributor Author

@XuehaiPan XuehaiPan Feb 10, 2022


In method get_inference_input_dict(), the key INFOS should be shifted together with OBS / T.

# initial step
obs = env.reset()
info = {}
t = -1
state = policy.model.get_initial_state()

results = policy.compute_single_action(obs,        # obs  at `t = -1`
                                       info=info,  # info at `t = -1`
                                       state=state)
action, state, *_ = results

# interaction (t = 0)
obs, reward, done, info = env.step(action)
t = 0
results = policy.compute_single_action(obs,        # obs  at `t = 0`
                                       info=info,  # info at `t = 0`
                                       state=state)
action, state, *_ = results

# interaction (t = 1)
obs, reward, done, info = env.step(action)
t = 1
...

In the buffer, for an environment transition tuple (obs, action, reward, next_obs, done, info), we have:

# time step
#  t      t      t+1       t+1    t+1   t+1
 (obs, action, next_obs, reward, done, info)

The values (next_obs, reward, done, info) are the results of the next time step: res = env.step(action).

This change will cause inconsistent time step alignment for OBS and INFOS between inference and training:

# Inference
# feed (obs, infos) to policy.compute_action()
t_OBS = t_INFOS

# Train
# tuple (obs, action, next_obs, reward, done, info)
t_OBS = t_INFOS - 1

In the policy model, the user should have:

class CustomModel:
    def __init__(self, ...):
        # ...
        self.view_requirements[SampleBatch.INFOS] = ViewRequirement()

    def forward(self, input_dict, state, seq_lens):
        if self.training:
            ...  # -> t_OBS + 1 = t_NEXT_OBS = t_INFOS
        else:  # sampling, postprocess_trajectory
            ...  # -> t_OBS = t_NEXT_OBS - 1 = t_INFOS

Alternative approach:

We can add new keys CUR_INFOS and NEXT_INFOS in SampleBatch to distinguish between sampling and training.

The key INFOS would then always refer to the info returned together with the observation (NEXT_OBS) from next_obs, r, d, info = env.step(). In the buffer, we have the tuple:

(obs, action, next_obs, reward, done, info) -> (obs, action, next_obs, reward, done, next_info)
# In sampler
t_OBS = t_CUR_INFOS = t_INFOS - 1
t_NEXT_OBS = t_NEXT_INFOS = t_INFOS  # invalid

# In trainer
t_OBS = t_CUR_INFOS = t_INFOS - 1
t_NEXT_OBS = t_NEXT_INFOS = t_INFOS
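
A rough sketch of how a collector could fill both proposed keys (a toy loop, not RLlib code; it assumes the classic Gym API where env.step() returns (obs, reward, done, info), and cur_infos / next_infos are only the hypothetical column names suggested above):

import gym

env = gym.make("CartPole-v1")
obs, cur_info = env.reset(), {}  # at t = -1 there is no info yet
buffer = []
for t in range(3):
    # stand-in for policy.compute_single_action(obs, info=cur_info, ...)
    action = env.action_space.sample()
    next_obs, reward, done, next_info = env.step(action)
    buffer.append({
        "obs": obs,               # step t
        "actions": action,        # step t
        "new_obs": next_obs,      # step t + 1
        "rewards": reward,
        "dones": done,
        "cur_infos": cur_info,    # info that was fed to the policy at step t
        "next_infos": next_info,  # info returned with new_obs (step t + 1)
    })
    obs, cur_info = next_obs, next_info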

@XuehaiPan
Contributor Author

Hi, @sven1977, could you have a look at this PR? Thanks!

@stale

stale bot commented Mar 12, 2022

This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.

  • If you'd like to keep this open, just leave any comment, and the stale label will be removed.

@stale stale bot added the stale The issue is stale. It will be closed within 7 days unless there are further conversation label Mar 12, 2022
@stale stale bot removed the stale The issue is stale. It will be closed within 7 days unless there are further conversation label Mar 14, 2022
@stale

stale bot commented Apr 14, 2022

This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.

  • If you'd like to keep this open, just leave any comment, and the stale label will be removed.

@stale stale bot added the stale The issue is stale. It will be closed within 7 days unless there are further conversation label Apr 14, 2022
@stale stale bot removed stale The issue is stale. It will be closed within 7 days unless there are further conversation labels Apr 15, 2022
@@ -115,7 +115,7 @@ def pad_batch_to_sequences_of_same_size(
elif (
not feature_keys
and not k.startswith("state_out_")
and k not in ["infos", SampleBatch.SEQ_LENS]
and k not in [SampleBatch.INFOS, SampleBatch.SEQ_LENS]
Contributor


Nice!

Contributor

@sven1977 sven1977 left a comment


Thanks for this PR @XuehaiPan and the detailed explanations! :)

@sven1977 sven1977 merged commit 6087eda into ray-project:master Apr 25, 2022
@XuehaiPan XuehaiPan deleted the samplebatch-slicing-infos branch July 22, 2022 09:08

Successfully merging this pull request may close these issues.

[RLlib][Bug] Wrong number of samples in batch[SampleBatch.INFOS] in RNN cases