[rllib] ExternalMultiAgentEnv #4200

Merged

Changes from 21 commits

Commits (22)
a934614
* start with building towards ExternalMultiAgentEnv
Feb 26, 2019
7731a85
* current issues:
Feb 28, 2019
7ec6caa
Merge branch 'master' into feature/external_multi_agent_env
Feb 28, 2019
5dca2c4
* refix multiagent_two_trainers.py example
ctombumila37 Feb 28, 2019
65d4c2b
Merge branch 'master' into feature/external_multi_agent_env
ctombumila37 Mar 1, 2019
5c1290e
* refactoring:
ctombumila37 Mar 1, 2019
6c7f218
* refactoring:
ctombumila37 Mar 7, 2019
99cd4e4
* fix "reward key error" by relying on observation agent keys
ctombumila37 Mar 18, 2019
669a0d1
Merge branch 'master' into feature/external_multi_agent_env
ctombumila37 Mar 19, 2019
e379088
* make ExternalMultiAgentEnv subclass ExternalEnv
ctombumila37 Mar 27, 2019
f48b374
Merge branch 'master' into feature/external_multi_agent_env
ctombumila37 Mar 27, 2019
d0431ac
* fix precedence of ExternalMultiAgentEnv over ExternalEnv (class hie…
ctombumila37 Mar 27, 2019
294783f
- remove redundant policy_server and policy_client (#4200)
ctombumila37 Mar 28, 2019
d6d22d1
Update external_multi_agent_env.py
ericl Mar 29, 2019
918ed01
Update external_multi_agent_env.py
ericl Mar 29, 2019
c89dbb3
Refactoring (#4200)
ctombumila37 Apr 1, 2019
774f301
Refactoring (#4200)
ctombumila37 Apr 2, 2019
ea1bf37
Merge branch 'master' into feature/external_multi_agent_env
ctombumila37 Apr 2, 2019
5034266
Add override import (#4200)
ctombumila37 Apr 2, 2019
c3f8f19
Merge branch 'master' into feature/external_multi_agent_env
ctombumila37 Apr 4, 2019
08cc22c
[rllib] Add first tests for ExternalMultiAgentEnv (#4200)
ctombumila37 Apr 4, 2019
378d19c
lint
ericl Apr 7, 2019
45 changes: 35 additions & 10 deletions python/ray/rllib/env/base_env.py
@@ -3,6 +3,7 @@
from __future__ import print_function

from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.annotations import override, PublicAPI
@@ -102,6 +103,11 @@ def to_base_env(env,
make_env=make_env,
existing_envs=[env],
num_envs=num_envs)
elif isinstance(env, ExternalMultiAgentEnv):
if num_envs != 1:
raise ValueError(
"ExternalMultiAgentEnv does not currently support num_envs > 1.")
env = _ExternalEnvToBaseEnv(env, multiagent=True)
elif isinstance(env, ExternalEnv):
if num_envs != 1:
raise ValueError(
@@ -203,9 +209,10 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
class _ExternalEnvToBaseEnv(BaseEnv):
"""Internal adapter of ExternalEnv to BaseEnv."""

def __init__(self, external_env, preprocessor=None):
def __init__(self, external_env, preprocessor=None, multiagent=False):
self.external_env = external_env
self.prep = preprocessor
self.multiagent = multiagent
self.action_space = external_env.action_space
if preprocessor:
self.observation_space = preprocessor.observation_space
@@ -230,16 +237,21 @@ def poll(self):

@override(BaseEnv)
def send_actions(self, action_dict):
for eid, action in action_dict.items():
self.external_env._episodes[eid].action_queue.put(
action[_DUMMY_AGENT_ID])
if self.multiagent:
for env_id, actions in action_dict.items():
self.external_env._episodes[env_id].action_queue.put(actions)
else:
for env_id, action in action_dict.items():
self.external_env._episodes[env_id].action_queue.put(
action[_DUMMY_AGENT_ID])

def _poll(self):
all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
off_policy_actions = {}
for eid, episode in self.external_env._episodes.copy().items():
data = episode.get_data()
if episode.cur_done:
cur_done = episode.cur_done_dict["__all__"] if self.multiagent else episode.cur_done
if cur_done:
del self.external_env._episodes[eid]
if data:
if self.prep:
@@ -251,11 +263,24 @@ def _poll(self):
all_infos[eid] = data["info"]
if "off_policy_action" in data:
off_policy_actions[eid] = data["off_policy_action"]
return _with_dummy_agent_id(all_obs), \
_with_dummy_agent_id(all_rewards), \
_with_dummy_agent_id(all_dones, "__all__"), \
_with_dummy_agent_id(all_infos), \
_with_dummy_agent_id(off_policy_actions)
if self.multiagent:
# ensure a consistent set of keys
# rely on all_obs having all possible keys for now
for eid, eid_dict in all_obs.items():
for agent_id in eid_dict.keys():
def fix(d, zero_val):
if agent_id not in d[eid]:
d[eid][agent_id] = zero_val
fix(all_rewards, 0.0)
fix(all_dones, False)
fix(all_infos, {})
return all_obs, all_rewards, all_dones, all_infos, off_policy_actions
else:
return _with_dummy_agent_id(all_obs), \
_with_dummy_agent_id(all_rewards), \
_with_dummy_agent_id(all_dones, "__all__"), \
_with_dummy_agent_id(all_infos), \
_with_dummy_agent_id(off_policy_actions)


class _VectorEnvToBaseEnv(BaseEnv):
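For illustration only, here is a standalone sketch of the key-padding step in the multiagent branch of `_poll` above, written with `dict.setdefault` instead of the nested `fix` helper; the episode and agent names are invented, not RLlib API:

```python
# Illustrative sketch (not RLlib API): any agent that appears in the
# observation dict but has no reward/done/info entry yet gets a default
# of 0.0 / False / {}.
all_obs = {"episode_1": {"agent_0": [0.1], "agent_1": [0.2]}}
all_rewards = {"episode_1": {"agent_0": 1.0}}   # agent_1 logged no reward yet
all_dones = {"episode_1": {"__all__": False}}
all_infos = {"episode_1": {}}

for eid, obs_dict in all_obs.items():
    for agent_id in obs_dict:
        all_rewards[eid].setdefault(agent_id, 0.0)
        all_dones[eid].setdefault(agent_id, False)
        all_infos[eid].setdefault(agent_id, {})

assert all_rewards == {"episode_1": {"agent_0": 1.0, "agent_1": 0.0}}
```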
77 changes: 55 additions & 22 deletions python/ray/rllib/env/external_env.py
@@ -184,53 +184,86 @@ def _get(self, episode_id):
class _ExternalEnvEpisode(object):
"""Tracked state for each active episode."""

def __init__(self, episode_id, results_avail_condition, training_enabled):
def __init__(self, episode_id, results_avail_condition, training_enabled, multiagent=False):
self.episode_id = episode_id
self.results_avail_condition = results_avail_condition
self.training_enabled = training_enabled
self.multiagent = multiagent
self.data_queue = queue.Queue()
self.action_queue = queue.Queue()
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
self.cur_done = False
self.cur_info = {}
if multiagent:
self.new_observation_dict = None
self.new_action_dict = None
self.cur_reward_dict = {}
self.cur_done_dict = {"__all__": False}
self.cur_info_dict = {}
else:
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
self.cur_done = False
self.cur_info = {}

def get_data(self):
if self.data_queue.empty():
return None
return self.data_queue.get_nowait()

def log_action(self, observation, action):
self.new_observation = observation
self.new_action = action
if self.multiagent:
self.new_observation_dict = observation
self.new_action_dict = action
else:
self.new_observation = observation
self.new_action = action
self._send()
self.action_queue.get(True, timeout=60.0)

def wait_for_action(self, observation):
self.new_observation = observation
if self.multiagent:
self.new_observation_dict = observation
else:
self.new_observation = observation
self._send()
return self.action_queue.get(True, timeout=60.0)

def done(self, observation):
self.new_observation = observation
self.cur_done = True
if self.multiagent:
self.new_observation_dict = observation
self.cur_done_dict = {"__all__": True}
else:
self.new_observation = observation
self.cur_done = True
self._send()

def _send(self):
item = {
"obs": self.new_observation,
"reward": self.cur_reward,
"done": self.cur_done,
"info": self.cur_info,
}
if self.new_action is not None:
item["off_policy_action"] = self.new_action
if self.multiagent:
item = {
"obs": self.new_observation_dict,
"reward": self.cur_reward_dict,
"done": self.cur_done_dict,
"info": self.cur_info_dict,
}
if self.new_action_dict is not None:
item["off_policy_action"] = self.new_action_dict
self.new_observation_dict = None
self.new_action_dict = None
self.cur_reward_dict = {}
else:
item = {
"obs": self.new_observation,
"reward": self.cur_reward,
"done": self.cur_done,
"info": self.cur_info,
}
if self.new_action is not None:
item["off_policy_action"] = self.new_action
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
if not self.training_enabled:
item["info"]["training_enabled"] = False
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
with self.results_avail_condition:
self.data_queue.put_nowait(item)
self.results_avail_condition.notify()
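As a rough illustration of what `_send` now puts on the data queue in multi-agent mode (all values below are made up), each item carries per-agent dicts where the single-agent case has plain values:

```python
# Shape of a multi-agent item placed on data_queue by _send (values invented):
multi_item = {
    "obs": {"agent_0": [0.1, 0.2], "agent_1": [0.3, 0.4]},
    "reward": {"agent_0": 1.0, "agent_1": -0.5},
    "done": {"__all__": False},
    "info": {},
}

# The single-agent item keeps the original flat layout:
single_item = {"obs": [0.1, 0.2], "reward": 1.0, "done": False, "info": {}}
```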

145 changes: 145 additions & 0 deletions python/ray/rllib/env/external_multi_agent_env.py
@@ -0,0 +1,145 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import queue
import threading
import uuid

from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.env.external_env import ExternalEnv, _ExternalEnvEpisode


@PublicAPI
class ExternalMultiAgentEnv(ExternalEnv):
"""This is the multi-agent version of ExternalEnv."""

@PublicAPI
def __init__(self, action_space, observation_space, max_concurrent=100):
"""Initialize a multi-agent external env.

ExternalMultiAgentEnv subclasses must call this during their __init__.

Arguments:
action_space (gym.Space): Action space of the env.
observation_space (gym.Space): Observation space of the env.
max_concurrent (int): Max number of active episodes to allow at
once. Exceeding this limit raises an error.
"""
ExternalEnv.__init__(
self, action_space, observation_space, max_concurrent)

        # we need to know all agents' spaces
if isinstance(self.action_space, dict) or isinstance(self.observation_space, dict):
Review comment (Contributor): I noticed sometimes you pass in None for the spaces here -- should that be allowed?

if not (self.action_space.keys() == self.observation_space.keys()):
raise ValueError("Agent ids disagree for action space and obs space dict: {} {}".format(self.action_space.keys(), self.observation_space.keys()))

@PublicAPI
def run(self):
"""Override this to implement the multi-agent run loop.

Your loop should continuously:
1. Call self.start_episode(episode_id)
2. Call self.get_action(episode_id, obs_dict)
-or-
self.log_action(episode_id, obs_dict, action_dict)
3. Call self.log_returns(episode_id, reward_dict)
4. Call self.end_episode(episode_id, obs_dict)
5. Wait if nothing to do.

Multiple episodes may be started at the same time.
"""
raise NotImplementedError

@PublicAPI
@override(ExternalEnv)
def start_episode(self, episode_id=None, training_enabled=True):
if episode_id is None:
episode_id = uuid.uuid4().hex

if episode_id in self._finished:
raise ValueError(
"Episode {} has already completed.".format(episode_id))

if episode_id in self._episodes:
raise ValueError(
"Episode {} is already started".format(episode_id))

self._episodes[episode_id] = _ExternalEnvEpisode(
episode_id, self._results_avail_condition, training_enabled, multiagent=True)

return episode_id

@PublicAPI
@override(ExternalEnv)
def get_action(self, episode_id, observation_dict):
"""Record an observation and get the on-policy action.
observation_dict is expected to contain the observation
of all agents acting in this episode step.

Arguments:
episode_id (str): Episode id returned from start_episode().
observation_dict (dict): Current environment observation.

Returns:
action (dict): Action from the env action space.
"""

episode = self._get(episode_id)
return episode.wait_for_action(observation_dict)

@PublicAPI
@override(ExternalEnv)
def log_action(self, episode_id, observation_dict, action_dict):
"""Record an observation and (off-policy) action taken.

Arguments:
episode_id (str): Episode id returned from start_episode().
observation_dict (dict): Current environment observation.
action_dict (dict): Action for the observation.
"""

episode = self._get(episode_id)
episode.log_action(observation_dict, action_dict)

@PublicAPI
@override(ExternalEnv)
def log_returns(self, episode_id, reward_dict, info_dict=None):
"""Record returns from the environment.

The reward will be attributed to the previous action taken by the
episode. Rewards accumulate until the next action. If no reward is
logged before the next action, a reward of 0.0 is assumed.

Arguments:
episode_id (str): Episode id returned from start_episode().
reward_dict (dict): Reward from the environment agents.
            info_dict (dict): Optional info dict.
"""

episode = self._get(episode_id)

# accumulate reward by agent
# for existing agents, we want to add the reward up
for agent, rew in reward_dict.items():
if agent in episode.cur_reward_dict:
episode.cur_reward_dict[agent] += rew
else:
episode.cur_reward_dict[agent] = rew
if info_dict:
episode.cur_info_dict = info_dict or {}

@PublicAPI
@override(ExternalEnv)
def end_episode(self, episode_id, observation_dict):
"""Record the end of an episode.

Arguments:
episode_id (str): Episode id returned from start_episode().
observation_dict (dict): Current environment observation.
"""

episode = self._get(episode_id)
self._finished.add(episode.episode_id)
episode.done(observation_dict)
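To make the run() contract above concrete, here is a hypothetical user-side subclass; `MySimulator`-style objects with reset/step methods are invented for illustration and are not part of this PR. If per-agent spaces differ, the constructor also accepts dicts keyed by agent id, as long as the action and observation space dicts share the same keys (see the check in __init__ above).

```python
import gym

from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv


class MyExternalMultiAgentEnv(ExternalMultiAgentEnv):
    """Hypothetical sketch of the run loop documented in run() above."""

    def __init__(self, simulator):
        ExternalMultiAgentEnv.__init__(
            self,
            action_space=gym.spaces.Discrete(2),
            observation_space=gym.spaces.Box(-1.0, 1.0, (4, )))
        # `simulator` stands in for some external multi-agent process.
        self.simulator = simulator

    def run(self):
        while True:
            # 1. Start a new episode.
            episode_id = self.start_episode()
            obs_dict = self.simulator.reset()  # {agent_id: obs}
            done = False
            while not done:
                # 2. Query on-policy actions for all agents that observed
                #    something this step.
                action_dict = self.get_action(episode_id, obs_dict)
                obs_dict, reward_dict, done = self.simulator.step(action_dict)
                # 3. Rewards accumulate per agent until the next action.
                self.log_returns(episode_id, reward_dict)
            # 4. Close out the episode with the final observations.
            self.end_episode(episode_id, obs_dict)
```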

6 changes: 4 additions & 2 deletions python/ray/rllib/evaluation/policy_evaluator.py
@@ -13,6 +13,7 @@
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
@@ -308,12 +309,13 @@ def make_env(vector_index):

self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent:
if not (isinstance(self.env, MultiAgentEnv)
if not ((isinstance(self.env, MultiAgentEnv)
or isinstance(self.env, ExternalMultiAgentEnv))
or isinstance(self.env, BaseEnv)):
raise ValueError(
"Have multiple policy graphs {}, but the env ".format(
self.policy_map) +
"{} is not a subclass of MultiAgentEnv?".format(self.env))
"{} is not a subclass of BaseEnv, MultiAgentEnv or ExternalMultiAgentEnv?".format(self.env))

self.filters = {
policy_id: get_filter(observation_filter,
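The extra isinstance branch matters because ExternalMultiAgentEnv deliberately subclasses ExternalEnv rather than MultiAgentEnv, so it matches neither the old check here nor the plain ExternalEnv branch in to_base_env unless it is tested first. A minimal sketch of those relationships follows; the dummy subclass and the None spaces are only for illustration (though the review comment above notes None is sometimes passed):

```python
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv


class DummyExternalMA(ExternalMultiAgentEnv):
    def run(self):
        pass  # never started in this sketch


env = DummyExternalMA(action_space=None, observation_space=None)

print(isinstance(env, ExternalEnv))    # True: it subclasses ExternalEnv, so
                                       # to_base_env must test the multi-agent
                                       # case before the ExternalEnv case.
print(isinstance(env, MultiAgentEnv))  # False: hence the added check in
                                       # PolicyEvaluator above.
```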