[RLlib] Policy.compute_actions_from_input_dict does not properly track accessed fields for the Policy's view-requirements. #14385

Closed
2 tasks done
sven1977 opened this issue Feb 26, 2021 · 1 comment · Fixed by #15856
Assignees
Labels
bug Something that is supposed to be working; but isn't P1 Issue that should be fixed within a few weeks


sven1977 commented Feb 26, 2021

Policy.compute_actions_from_input_dict does not properly track accessed fields for the Policy's view-requirements.
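For readers unfamiliar with the idea of tracking accessed fields, here is a minimal sketch of the general technique. It is not RLlib's implementation: `TrackingDict`, `prune_view_requirements`, and the placeholder spec strings are hypothetical names made up for illustration. The sketch assumes that a batch records every key read from it and that requirements for fields that were never read are then dropped.

```python
class TrackingDict(dict):
    """Hypothetical dict that records every key read via `d[key]`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()

    def __getitem__(self, key):
        # Record the read before delegating to the normal dict lookup.
        self.accessed_keys.add(key)
        return super().__getitem__(key)


def prune_view_requirements(view_requirements, accessed_keys):
    """Keep only the requirements whose field was actually read."""
    return {k: v for k, v in view_requirements.items() if k in accessed_keys}


# Fake batch with three fields; the values are placeholders.
batch = TrackingDict({"obs": [[0.0, 0.0]], "infos": [{}], "rewards": [0.0]})

# Stand-in for a policy method that reads two of the three fields.
_ = batch["obs"]
_ = batch["infos"]

# Placeholder requirement specs (assumed, for illustration only).
view_requirements = {
    "obs": "obs spec",
    "infos": "infos spec",
    "rewards": "rewards spec",
}
needed = prune_view_requirements(view_requirements, batch.accessed_keys)
print(sorted(needed))  # ['infos', 'obs']
```

In this sketch, any read that bypasses the tracking (for example, a read on a plain copy of the dict) would cause that field's requirement to be pruned even though the code depends on it.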

What is the problem?

Ray version and other system information (Python version, TensorFlow version, OS):

Reproduction (REQUIRED)


To reproduce:

import torch
import numpy as np

from ray.rllib.agents.ppo import PPOTorchPolicy, PPOTrainer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import add_mixins
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
    convert_to_torch_tensor


class MixIn:
    def compute_actions_from_input_dict(self, input_dict, explore=None, timestep=None, episodes=None, **kwargs):

        #if not self.config['env_config'].get('invalid_action_masking', False):
        #    raise RuntimeError('invalid_action_masking must be set to True in env_config to use this mixin')

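        # Read the INFOS field from the input dict. This is a field access
        # made inside compute_actions_from_input_dict, the kind the issue
        # title says is not tracked for the view-requirements.
        infos = input_dict[SampleBatch.INFOS]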

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths.
            seq_lens = np.array([1] * len(input_dict["obs"])) \
                if state_batches else None

            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(
                explore=explore, timestep=timestep)

            dist_inputs, state_out = self.model(input_dict, state_batches,
                                                seq_lens)
            # Determine the batch size from the observation tensor.
            batch_size = input_dict["obs"].shape[0]
            return convert_to_non_torch_type((
                torch.zeros((batch_size,), dtype=torch.int32),
                state_out,
                {
                    "vf_preds": torch.zeros((batch_size, ), dtype=torch.float32),
                    "action_dist_inputs": dist_inputs,
                    "action_logp": torch.ones((batch_size,), dtype=torch.float32)
                }
            ))


import ray
ray.init(local_mode=True)

MyPolicy = add_mixins(PPOTorchPolicy, [MixIn])
MyTrainer = PPOTrainer.with_updates(default_policy=MyPolicy, get_policy_class=lambda c: MyPolicy)

trainer = MyTrainer(env="CartPole-v0", config={
    "framework": "torch",
    "num_workers": 0,
})

for _ in range(20):
    print(trainer.train())
  • I have verified my script runs in a clean environment and reproduces the issue.
  • I have verified the issue also occurs with the latest wheels.
@sven1977 sven1977 added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) P2 Important issue, but not time-critical rllib and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Feb 26, 2021
@sven1977 sven1977 self-assigned this Feb 26, 2021
@sven1977 sven1977 added P1 Issue that should be fixed within a few weeks and removed P2 Important issue, but not time-critical labels Mar 11, 2021
@ericl ericl added this to the RLlib Bugs milestone Mar 11, 2021
@ericl ericl removed the rllib label Mar 11, 2021

Bam4d commented May 17, 2021

@sven1977 This issue is occurring again on the current master of RLlib.
