Skip to content

Commit

Permalink
[RLlib] Use PyTorch vectorized max() and sum() in SampleBatch.__init_…
Browse files Browse the repository at this point in the history
…_ when possible (#28388)

Signed-off-by: Cassidy Laidlaw <[email protected]>
  • Loading branch information
cassidylaidlaw authored Jan 12, 2023
1 parent 9ab7a16 commit 2bc8837
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions rllib/policy/sample_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def attempt_count_timesteps(tensor_dict: dict):
and not (tf and tf.is_tensor(tensor_dict[SampleBatch.SEQ_LENS]))
and len(tensor_dict[SampleBatch.SEQ_LENS]) > 0
):
return sum(tensor_dict[SampleBatch.SEQ_LENS])
if torch and torch.is_tensor(tensor_dict[SampleBatch.SEQ_LENS]):
return tensor_dict[SampleBatch.SEQ_LENS].sum().item()
else:
return sum(tensor_dict[SampleBatch.SEQ_LENS])

for k, v in copy_.items():
assert isinstance(k, str), tensor_dict
Expand Down Expand Up @@ -269,7 +272,10 @@ def __init__(self, *args, **kwargs):
and not (tf and tf.is_tensor(seq_lens_))
and len(seq_lens_) > 0
):
self.max_seq_len = max(seq_lens_)
if torch and torch.is_tensor(seq_lens_):
self.max_seq_len = seq_lens_.max().item()
else:
self.max_seq_len = max(seq_lens_)

if self._is_training is None:
self._is_training = self.pop("is_training", False)
Expand Down

0 comments on commit 2bc8837

Please sign in to comment.