[rllib] Support prev_state/prev_action in rollout and fix multiagent #4565

Merged · 10 commits · Apr 10, 2019
python/ray/rllib/rollout.py: 93 changes (62 additions, 31 deletions)
@@ -5,13 +5,16 @@
from __future__ import print_function

import argparse
import collections
import json
import os
import pickle

import gym
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.tune.util import merge_dicts

@@ -102,17 +105,35 @@ def run(args, parser):
rollout(agent, args.env, num_steps, args.out, args.no_render)


class DefaultMapping(collections.defaultdict):
"""default_factory now takes as an argument the missing key."""

def __missing__(self, key):
self[key] = value = self.default_factory(key)
return value


def default_policy_agent_mapping(unused_agent_id):
return DEFAULT_POLICY_ID
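
A minimal sketch of how `DefaultMapping` behaves, using toy keys and values that are not part of the PR: unlike a plain `collections.defaultdict`, the overridden `__missing__` passes the missing key to the factory and caches the result.

```python
import collections


class DefaultMapping(collections.defaultdict):
    """defaultdict whose default_factory receives the missing key."""

    def __missing__(self, key):
        self[key] = value = self.default_factory(key)
        return value


# The factory is called with the missing key, and the result is cached.
squares = DefaultMapping(lambda key: key * key)
assert squares[3] == 9          # __missing__(3) called the factory with 3
assert dict(squares) == {3: 9}  # later lookups hit the stored value
```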


def rollout(agent, env_name, num_steps, out=None, no_render=True):
policy_agent_mapping = default_policy_agent_mapping

if hasattr(agent, "local_evaluator"):
env = agent.local_evaluator.env
multiagent = agent.local_evaluator.multiagent
if multiagent:
multiagent = isinstance(env, MultiAgentEnv)
if agent.local_evaluator.multiagent:
policy_agent_mapping = agent.config["multiagent"][
"policy_mapping_fn"]
mapping_cache = {}

policy_map = agent.local_evaluator.policy_map
state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
action_init = {
p: m.action_space.sample()
Contributor: note that in the actual code we feed the all-zeros action initially

Contributor Author: Does that make sense? All zeros might not even be in the action space?

Contributor: Yeah that's a good question.

Contributor: I think this is fine for now, though it would be better to set it to zeros for consistency (or switch the sampler to stick in a random initial action, but that might be weird).

for p, m in policy_map.items()
}
else:
env = gym.make(env_name)
multiagent = False
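
On the initial-action thread above: a hedged sketch of what an "all-zeros" default could look like for common Gym spaces. `zero_action` is a hypothetical helper, not part of this PR; zeros are clipped into bounds for Box spaces (addressing the concern that 0 may not be in the action space), and other space types fall back to sampling, as rollout.py does here.

```python
import gym
import numpy as np


def zero_action(space):
    # Hypothetical helper approximating the all-zeros action the sampler
    # feeds at t=0. Clip into [low, high] since zero may lie outside a
    # Box's bounds; fall back to sampling for other space types.
    if isinstance(space, gym.spaces.Box):
        return np.clip(
            np.zeros(space.shape, dtype=space.dtype), space.low, space.high)
    if isinstance(space, gym.spaces.Discrete):
        return 0
    return space.sample()
```
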
@@ -122,39 +143,49 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
rollouts = []
steps = 0
while steps < (num_steps or steps + 1):
mapping_cache = {} # in case policy_agent_mapping is stochastic
if out is not None:
rollout = []
state = env.reset()
obs = env.reset()
Contributor: should the mapping cache be reset as well?

Contributor Author: Yeah I think that's appropriate.

multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
agent_states = DefaultMapping(
lambda agent_id: state_init[mapping_cache[agent_id]])
prev_actions = DefaultMapping(
lambda agent_id: action_init[mapping_cache[agent_id]])
prev_rewards = collections.defaultdict(lambda: 0.)
done = False
reward_total = 0.0
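
As the reviewer suggested above, `mapping_cache` is rebuilt at the start of each episode, so a stochastic `policy_mapping_fn` can reassign agents between episodes while staying fixed within one. A small sketch with made-up policy names:

```python
import random


def policy_mapping_fn(agent_id):
    # Hypothetical stochastic mapping: re-drawn for each agent per episode.
    return random.choice(["policy_a", "policy_b"])


mapping_cache = {}  # reset once per episode, as in the loop above
first = mapping_cache.setdefault("agent_0", policy_mapping_fn("agent_0"))
# Later steps in the same episode reuse the cached assignment rather than
# re-rolling the mapping:
assert mapping_cache.setdefault("agent_0", policy_mapping_fn("agent_0")) == first
```
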
while not done and steps < (num_steps or steps + 1):
action_dict = {}
for agent_id, a_obs in multi_obs.items():
if a_obs is not None:
policy_id = mapping_cache.setdefault(
agent_id, policy_agent_mapping(agent_id))
p_use_lstm = use_lstm[policy_id]
if p_use_lstm:
a_action, p_state, _ = agent.compute_action(
a_obs,
state=agent_states[agent_id],
prev_action=prev_actions[agent_id],
prev_reward=prev_rewards[agent_id],
policy_id=policy_id)
agent_states[agent_id] = p_state
else:
a_action = agent.compute_action(
a_obs,
prev_action=prev_actions[agent_id],
prev_reward=prev_rewards[agent_id],
policy_id=policy_id)
action_dict[agent_id] = a_action
prev_actions[agent_id] = a_action
action = action_dict

next_obs, reward, done, _ = env.step(action)
if multiagent:
action_dict = {}
for agent_id in state.keys():
a_state = state[agent_id]
if a_state is not None:
policy_id = mapping_cache.setdefault(
agent_id, policy_agent_mapping(agent_id))
p_use_lstm = use_lstm[policy_id]
if p_use_lstm:
a_action, p_state_init, _ = agent.compute_action(
a_state,
state=state_init[policy_id],
policy_id=policy_id)
state_init[policy_id] = p_state_init
else:
a_action = agent.compute_action(
a_state, policy_id=policy_id)
action_dict[agent_id] = a_action
action = action_dict
for agent_id, r in reward.items():
prev_rewards[agent_id] = r
else:
if use_lstm[DEFAULT_POLICY_ID]:
action, state_init, _ = agent.compute_action(
state, state=state_init)
else:
action = agent.compute_action(state)

next_state, reward, done, _ = env.step(action)
prev_rewards[_DUMMY_AGENT_ID] = reward

if multiagent:
done = done["__all__"]
@@ -164,9 +195,9 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
if not no_render:
env.render()
if out is not None:
rollout.append([state, action, next_state, reward, done])
rollout.append([obs, action, next_obs, reward, done])
steps += 1
state = next_state
obs = next_obs
if out is not None:
rollouts.append(rollout)
print("Episode reward", reward_total)
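
With these changes, a checkpoint trained with an LSTM policy or a multi-agent setup can be rolled out with the correct prev_action/prev_reward/state inputs. A sketch of programmatic use, assuming the registry API imported in the diff above; the trainer name, env, and checkpoint path are placeholders:

```python
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.rollout import rollout

ray.init()
agent_cls = get_agent_class("PPO")                    # placeholder trainer
agent = agent_cls(env="CartPole-v0")                  # placeholder env
agent.restore("/tmp/ray/checkpoint_1/checkpoint-1")   # placeholder path
rollout(agent, "CartPole-v0", num_steps=1000)
```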