Policy.compute_actions_from_input_dict does not properly track accessed fields for the Policy's view-requirements.
What is the problem?
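In RLlib, columns that a policy reads from its input dict are supposed to be recorded and turned into view-requirements, so that the sampler keeps delivering those columns during rollouts. A minimal illustration of that tracking (not from the original report; it assumes SampleBatch records reads in its accessed_keys set):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({"obs": np.zeros((1, 4)), "infos": np.array([{}])})
_ = batch["infos"]          # a read like `input_dict[SampleBatch.INFOS]` below
print(batch.accessed_keys)  # expected to contain "infos"

When the same kind of read happens inside compute_actions_from_input_dict, the accessed field does not make it into the Policy's view-requirements.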
Ray version and other system information (Python version, TensorFlow version, OS):
Reproduction (REQUIRED)
Please provide a short code snippet (less than 50 lines if possible) that can be copy-pasted to reproduce the issue. The snippet should have no external library dependencies (i.e., use fake or mock data / environments):
If the code snippet cannot be run by itself, the issue will be closed with "needs-repro-script".
To reproduce:
import torch
import numpy as np
from ray.rllib.agents.ppo import PPOTorchPolicy, PPOTrainer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import add_mixins
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
    convert_to_torch_tensor
class MixIn:
    def compute_actions_from_input_dict(self, input_dict, explore=None,
                                        timestep=None, episodes=None, **kwargs):
        # if not self.config['env_config'].get('invalid_action_masking', False):
        #     raise RuntimeError('invalid_action_masking must be set to True '
        #                        'in env_config to use this mixin')
        # This read of the "infos" column is the access that should be tracked
        # as a view requirement, but is not.
        infos = input_dict[SampleBatch.INFOS]
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths.
            seq_lens = np.array([1] * len(input_dict["obs"])) \
                if state_batches else None
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(
                explore=explore, timestep=timestep)
            dist_inputs, state_out = self.model(input_dict, state_batches,
                                                seq_lens)
            # Return dummy actions and extra outputs; only the shapes matter
            # for this repro.
            batch_size = input_dict["obs"].shape[0]
            return convert_to_non_torch_type((
                torch.zeros((batch_size,), dtype=torch.int32),
                state_out,
                {
                    "vf_preds": torch.zeros((batch_size,), dtype=torch.float32),
                    "action_dist_inputs": dist_inputs,
                    "action_logp": torch.ones((batch_size,), dtype=torch.float32)
                }
            ))
import ray
ray.init(local_mode=True)
MyPolicy = add_mixins(PPOTorchPolicy, [MixIn])
MyTrainer = PPOTrainer.with_updates(default_policy=MyPolicy, get_policy_class=lambda c: MyPolicy)
trainer = MyTrainer(env="CartPole-v0", config={
    "framework": "torch",
    "num_workers": 0,
})
for _ in range(20):
    print(trainer.train())
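A quick way to observe the reported behavior once the trainer is built (assuming the Policy exposes its view-requirements via the view_requirements dict):

policy = trainer.get_policy()
print(SampleBatch.INFOS in policy.view_requirements)
# Expected: True, because the mixin reads input_dict[SampleBatch.INFOS].
# Reported: accesses made inside compute_actions_from_input_dict are not
# tracked, so the "infos" column never becomes a view requirement.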
I have verified my script runs in a clean environment and reproduces the issue.
I have verified the issue also occurs with the latest wheels.