diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py
index 2fd81b12032f..af12e9c594f6 100755
--- a/python/ray/rllib/rollout.py
+++ b/python/ray/rllib/rollout.py
@@ -5,6 +5,7 @@
 from __future__ import print_function
 
 import argparse
+import collections
 import json
 import os
 import pickle
@@ -12,6 +13,8 @@
 import gym
 
 import ray
 from ray.rllib.agents.registry import get_agent_class
+from ray.rllib.env import MultiAgentEnv
+from ray.rllib.env.base_env import _DUMMY_AGENT_ID
 from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
 from ray.tune.util import merge_dicts
@@ -102,17 +105,35 @@ def run(args, parser):
     rollout(agent, args.env, num_steps, args.out, args.no_render)
 
 
+class DefaultMapping(collections.defaultdict):
+    """default_factory now takes the missing key as an argument."""
+
+    def __missing__(self, key):
+        self[key] = value = self.default_factory(key)
+        return value
+
+
+def default_policy_agent_mapping(unused_agent_id):
+    return DEFAULT_POLICY_ID
+
+
 def rollout(agent, env_name, num_steps, out=None, no_render=True):
+    policy_agent_mapping = default_policy_agent_mapping
+
     if hasattr(agent, "local_evaluator"):
         env = agent.local_evaluator.env
-        multiagent = agent.local_evaluator.multiagent
-        if multiagent:
+        multiagent = isinstance(env, MultiAgentEnv)
+        if agent.local_evaluator.multiagent:
             policy_agent_mapping = agent.config["multiagent"][
                 "policy_mapping_fn"]
-            mapping_cache = {}
+
         policy_map = agent.local_evaluator.policy_map
         state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
         use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
+        action_init = {
+            p: m.action_space.sample()
+            for p, m in policy_map.items()
+        }
     else:
         env = gym.make(env_name)
         multiagent = False
@@ -122,39 +143,50 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
         rollouts = []
     steps = 0
     while steps < (num_steps or steps + 1):
+        mapping_cache = {}  # in case policy_agent_mapping is stochastic
         if out is not None:
             rollout = []
-        state = env.reset()
+        obs = env.reset()
+        agent_states = DefaultMapping(
+            lambda agent_id: state_init[mapping_cache[agent_id]])
+        prev_actions = DefaultMapping(
+            lambda agent_id: action_init[mapping_cache[agent_id]])
+        prev_rewards = collections.defaultdict(lambda: 0.)
         done = False
         reward_total = 0.0
         while not done and steps < (num_steps or steps + 1):
+            multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
+            action_dict = {}
+            for agent_id, a_obs in multi_obs.items():
+                if a_obs is not None:
+                    policy_id = mapping_cache.setdefault(
+                        agent_id, policy_agent_mapping(agent_id))
+                    p_use_lstm = use_lstm[policy_id]
+                    if p_use_lstm:
+                        a_action, p_state, _ = agent.compute_action(
+                            a_obs,
+                            state=agent_states[agent_id],
+                            prev_action=prev_actions[agent_id],
+                            prev_reward=prev_rewards[agent_id],
+                            policy_id=policy_id)
+                        agent_states[agent_id] = p_state
+                    else:
+                        a_action = agent.compute_action(
+                            a_obs,
+                            prev_action=prev_actions[agent_id],
+                            prev_reward=prev_rewards[agent_id],
+                            policy_id=policy_id)
+                    action_dict[agent_id] = a_action
+                    prev_actions[agent_id] = a_action
+            action = action_dict
+
+            action = action if multiagent else action[_DUMMY_AGENT_ID]
+            next_obs, reward, done, _ = env.step(action)
             if multiagent:
-                action_dict = {}
-                for agent_id in state.keys():
-                    a_state = state[agent_id]
-                    if a_state is not None:
-                        policy_id = mapping_cache.setdefault(
-                            agent_id, policy_agent_mapping(agent_id))
-                        p_use_lstm = use_lstm[policy_id]
-                        if p_use_lstm:
-                            a_action, p_state_init, _ = agent.compute_action(
-                                a_state,
-                                state=state_init[policy_id],
-                                policy_id=policy_id)
-                            state_init[policy_id] = p_state_init
-                        else:
-                            a_action = agent.compute_action(
-                                a_state, policy_id=policy_id)
-                        action_dict[agent_id] = a_action
-                action = action_dict
+                for agent_id, r in reward.items():
+                    prev_rewards[agent_id] = r
             else:
-                if use_lstm[DEFAULT_POLICY_ID]:
-                    action, state_init, _ = agent.compute_action(
-                        state, state=state_init)
-                else:
-                    action = agent.compute_action(state)
-
-            next_state, reward, done, _ = env.step(action)
+                prev_rewards[_DUMMY_AGENT_ID] = reward
 
             if multiagent:
                 done = done["__all__"]
@@ -164,9 +196,9 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
             if not no_render:
                 env.render()
             if out is not None:
-                rollout.append([state, action, next_state, reward, done])
+                rollout.append([obs, action, next_obs, reward, done])
             steps += 1
-            state = next_state
+            obs = next_obs
         if out is not None:
             rollouts.append(rollout)
         print("Episode reward", reward_total)
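
For readers skimming the patch, here is a minimal standalone sketch of the `DefaultMapping` behavior the diff introduces. Unlike a plain `collections.defaultdict`, whose `default_factory` is called with no arguments, the `__missing__` override passes the missing key to the factory; the rollout loop uses this to lazily initialize per-agent state from whichever policy the agent maps to. The `state_init`/`mapping_cache` values below are toy stand-ins, not RLlib objects.

```python
import collections


class DefaultMapping(collections.defaultdict):
    """defaultdict whose default_factory receives the missing key."""

    def __missing__(self, key):
        # Call the factory with the key, cache the result, and return it.
        self[key] = value = self.default_factory(key)
        return value


# Toy stand-ins for the per-policy initial state and the agent->policy cache.
state_init = {"default_policy": [0.0, 0.0]}
mapping_cache = {"agent_0": "default_policy"}

agent_states = DefaultMapping(
    lambda agent_id: state_init[mapping_cache[agent_id]])

print(agent_states["agent_0"])    # [0.0, 0.0], built on first access
print("agent_0" in agent_states)  # True: the value was cached by __missing__
```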
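
The rewritten loop body itself rests on one normalization trick: wrap a single-agent observation in a dict keyed by `_DUMMY_AGENT_ID` so the same per-agent action loop serves both plain gym envs and `MultiAgentEnv`s, then unwrap the action dict before `env.step`. A hedged sketch of that pattern, with `compute_fn` standing in for `agent.compute_action` and a string standing in for RLlib's real `_DUMMY_AGENT_ID` constant:

```python
_DUMMY_AGENT_ID = "agent0"  # illustrative stand-in for the RLlib constant


def compute_actions(obs, multiagent, compute_fn):
    """Return an action dict for multi-agent envs, a bare action otherwise."""
    # Normalize: a single-agent obs becomes a one-entry multi-agent dict.
    multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
    action_dict = {}
    for agent_id, a_obs in multi_obs.items():
        # Agents that returned no observation this step emit no action.
        if a_obs is not None:
            action_dict[agent_id] = compute_fn(agent_id, a_obs)
    # Unwrap again so plain gym envs receive a bare action.
    return action_dict if multiagent else action_dict[_DUMMY_AGENT_ID]


# Toy usage with a "policy" that transforms the observation.
print(compute_actions({"a": 1, "b": None}, True, lambda aid, o: o + 1))  # {'a': 2}
print(compute_actions(5, False, lambda aid, o: o * 2))                   # 10
```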