[rllib] Support prev_state/prev_action in rollout and fix multiagent #4565
Merged
Changes from 7 commits
Commits (10):
8b9be61 (vladfi1): Cleaner and more correct treatment of agent states in rollout.py
dac943b (vladfi1): support lstm_use_prev_action_reward in rollout.py
35b735a (vladfi1): Linter.
ae0bc2c (vladfi1): appease flake8
578fb14 (vladfi1): Use _DUMMY_AGENT_ID instead of 0.
8d05633 (vladfi1): All agents have a policy_agent_mapping.
50adcc9 (ericl): Update rollout.py
73f0167 (vladfi1): Fix rollout.py for single-agent envs.
486ad06 (vladfi1): Merge branch 'rollout-fix' of github.com:vladfi1/ray into rollout-fix
5fd9ef5 (vladfi1): Use agent_id, not policy_id.
Diff of rollout.py (original and updated lines are shown interleaved, without +/- markers):
@@ -5,13 +5,16 @@
from __future__ import print_function

import argparse
import collections
import json
import os
import pickle

import gym
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.tune.util import merge_dicts

@@ -102,17 +105,35 @@ def run(args, parser):
    rollout(agent, args.env, num_steps, args.out, args.no_render)


class DefaultMapping(collections.defaultdict):
    """default_factory now takes as an argument the missing key."""

    def __missing__(self, key):
        self[key] = value = self.default_factory(key)
        return value


def default_policy_agent_mapping(unused_agent_id):
    return DEFAULT_POLICY_ID


def rollout(agent, env_name, num_steps, out=None, no_render=True):
    policy_agent_mapping = default_policy_agent_mapping

    if hasattr(agent, "local_evaluator"):
        env = agent.local_evaluator.env
        multiagent = agent.local_evaluator.multiagent
        if multiagent:
        multiagent = isinstance(env, MultiAgentEnv)
        if agent.local_evaluator.multiagent:
            policy_agent_mapping = agent.config["multiagent"][
                "policy_mapping_fn"]
            mapping_cache = {}

        policy_map = agent.local_evaluator.policy_map
        state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
        use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
        action_init = {
            p: m.action_space.sample()
            for p, m in policy_map.items()
        }
    else:
        env = gym.make(env_name)
        multiagent = False

@@ -122,39 +143,49 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
        rollouts = []
    steps = 0
    while steps < (num_steps or steps + 1):
        mapping_cache = {}  # in case policy_agent_mapping is stochastic
        if out is not None:
            rollout = []
        state = env.reset()
        obs = env.reset()
Review comment: should the mapping cache be reset as well?
Reply: Yeah, I think that's appropriate.
        multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
        agent_states = DefaultMapping(
            lambda agent_id: state_init[mapping_cache[agent_id]])
        prev_actions = DefaultMapping(
            lambda agent_id: action_init[mapping_cache[agent_id]])
        prev_rewards = collections.defaultdict(lambda: 0.)
        done = False
        reward_total = 0.0
        while not done and steps < (num_steps or steps + 1):
            action_dict = {}
            for agent_id, a_obs in multi_obs.items():
                if a_obs is not None:
                    policy_id = mapping_cache.setdefault(
                        agent_id, policy_agent_mapping(agent_id))
                    p_use_lstm = use_lstm[policy_id]
                    if p_use_lstm:
                        a_action, p_state, _ = agent.compute_action(
                            a_obs,
                            state=agent_states[agent_id],
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id)
                        agent_states[policy_id] = p_state
                    else:
                        a_action = agent.compute_action(
                            a_obs,
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id)
                    action_dict[agent_id] = a_action
                    prev_actions[agent_id] = a_action
            action = action_dict

            next_obs, reward, done, _ = env.step(action)
            if multiagent:
                action_dict = {}
                for agent_id in state.keys():
                    a_state = state[agent_id]
                    if a_state is not None:
                        policy_id = mapping_cache.setdefault(
                            agent_id, policy_agent_mapping(agent_id))
                        p_use_lstm = use_lstm[policy_id]
                        if p_use_lstm:
                            a_action, p_state_init, _ = agent.compute_action(
                                a_state,
                                state=state_init[policy_id],
                                policy_id=policy_id)
                            state_init[policy_id] = p_state_init
                        else:
                            a_action = agent.compute_action(
                                a_state, policy_id=policy_id)
                        action_dict[agent_id] = a_action
                action = action_dict
                for agent_id, r in reward.items():
                    prev_rewards[agent_id] = r
            else:
                if use_lstm[DEFAULT_POLICY_ID]:
                    action, state_init, _ = agent.compute_action(
                        state, state=state_init)
                else:
                    action = agent.compute_action(state)

            next_state, reward, done, _ = env.step(action)
                prev_rewards[_DUMMY_AGENT_ID] = reward

            if multiagent:
                done = done["__all__"]

@@ -164,9 +195,9 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
            if not no_render:
                env.render()
            if out is not None:
                rollout.append([state, action, next_state, reward, done])
                rollout.append([obs, action, next_obs, reward, done])
            steps += 1
            state = next_state
            obs = next_obs
        if out is not None:
            rollouts.append(rollout)
        print("Episode reward", reward_total)
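For context on the DefaultMapping helper introduced in this diff: it is a collections.defaultdict whose factory receives the missing key, which lets the rollout loop lazily fill in each agent's initial RNN state and previous action the first time that agent is seen. Below is a minimal, self-contained sketch of the same pattern; the policy and agent names are made up for illustration and stand in for what policy_map, get_initial_state(), and action_space.sample() produce in the diff above.

import collections


class DefaultMapping(collections.defaultdict):
    """Like defaultdict, except default_factory receives the missing key."""

    def __missing__(self, key):
        self[key] = value = self.default_factory(key)
        return value


# Hypothetical stand-ins for policy.get_initial_state() and
# policy.action_space.sample() from the diff above.
state_init = {"policy_a": [0.0, 0.0], "policy_b": []}
action_init = {"policy_a": 1, "policy_b": 0}
mapping_cache = {"agent_0": "policy_a", "agent_1": "policy_b"}

agent_states = DefaultMapping(
    lambda agent_id: state_init[mapping_cache[agent_id]])
prev_actions = DefaultMapping(
    lambda agent_id: action_init[mapping_cache[agent_id]])

print(agent_states["agent_0"])  # [0.0, 0.0], filled in on first access
print(prev_actions["agent_1"])  # 0

This is why the inner loop can index agent_states[agent_id] and prev_actions[agent_id] without pre-populating them at the start of each episode.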
Review comment: note that in the actual code we feed the all-zeros action initially.
Reply: Does that make sense? All zeros might not even be in the action space?
Reply: Yeah, that's a good question.
Reply: I think this is fine for now, though it would be better to set it to zeros for consistency (or switch the sampler to stick in a random initial action, but that might be weird).
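To make the two options in this thread concrete, here is a hedged sketch (not RLlib code; the zero_action helper is hypothetical and only covers the common Box and Discrete cases) of constructing an all-zeros initial previous action, as the sampler reportedly feeds, versus sampling one from the action space, as the diff above does.

import gym
import numpy as np


def zero_action(space):
    """Hypothetical helper: build an 'all-zeros' action for common spaces.

    As raised in the review, an all-zeros vector is not guaranteed to lie
    inside an arbitrary Box (e.g. Box(low=1.0, high=2.0)), whereas
    space.sample() always returns a valid action.
    """
    if isinstance(space, gym.spaces.Box):
        return np.zeros(space.shape, dtype=space.dtype)
    if isinstance(space, gym.spaces.Discrete):
        return 0
    raise NotImplementedError(space)


space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

# Zeros, for consistency with what the sampler reportedly feeds initially.
prev_action_zeros = zero_action(space)

# What the diff above does: a random but always-valid initial action.
prev_action_sampled = space.sample()

The trade-off is the one named in the thread: zeros match the sampler's behavior but may not be a valid action for every space, while sampling is always valid but inconsistent with what the policy saw during training.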