
[rllib] annotate public vs developer vs private APIs #3808

Merged: 15 commits, Jan 24, 2019
9 changes: 9 additions & 0 deletions doc/source/rllib-dev.rst
@@ -6,6 +6,15 @@ Development Install

You can develop RLlib locally without needing to compile Ray by using the `setup-rllib-dev.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/setup-rllib-dev.py>`__ script. This sets up links between the ``rllib`` dir in your git repo and the one bundled with the ``ray`` package. When using this script, make sure that your git branch is in sync with the installed Ray binaries (i.e., you are up to date on `master <https://github.com/ray-project/ray>`__ and have the latest `wheel <https://ray.readthedocs.io/en/latest/installation.html>`__ installed).

API Stability
-------------

Objects and methods annotated with ``@PublicAPI`` or ``@DeveloperAPI`` have the following API compatibility guarantees:

.. autofunction:: ray.rllib.utils.annotations.PublicAPI

.. autofunction:: ray.rllib.utils.annotations.DeveloperAPI
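
The annotations themselves are lightweight decorators used as stability markers. A minimal sketch of what such pass-through decorators could look like (illustrative only; the actual definitions live in ``ray.rllib.utils.annotations`` and may differ)::

    def PublicAPI(obj):
        """Mark an API as stable for end users across RLlib releases."""
        # Pass-through: the decorator only tags the object as a public API.
        return obj

    def DeveloperAPI(obj):
        """Mark an API as intended for authors of custom algorithms.

        Developer APIs are lower-level and may change more often than
        public APIs.
        """
        return obj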

Features
--------

4 changes: 2 additions & 2 deletions doc/source/rllib-env.rst
@@ -310,6 +310,6 @@ Note that envs can read from different partitions of the logs based on the ``wor
Batch Asynchronous
------------------

The lowest-level "catch-all" environment supported by RLlib is `AsyncVectorEnv <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/async_vector_env.py>`__. AsyncVectorEnv models multiple agents executing asynchronously in multiple environments. A call to ``poll()`` returns observations from ready agents keyed by their environment and agent ids, and actions for those agents can be sent back via ``send_actions()``. This interface can be subclassed directly to support batched simulators such as `ELF <https://github.com/facebookresearch/ELF>`__.
The lowest-level "catch-all" environment supported by RLlib is `BaseEnv <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/base_env.py>`__. BaseEnv models multiple agents executing asynchronously in multiple environments. A call to ``poll()`` returns observations from ready agents keyed by their environment and agent ids, and actions for those agents can be sent back via ``send_actions()``. This interface can be subclassed directly to support batched simulators such as `ELF <https://github.com/facebookresearch/ELF>`__.

Under the hood, all other envs are converted to AsyncVectorEnv by RLlib so that there is a common internal path for policy evaluation.
Under the hood, all other envs are converted to BaseEnv by RLlib so that there is a common internal path for policy evaluation.
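
For illustration, a highly simplified polling loop against this interface could look like the sketch below (``my_env`` and ``compute_action`` are placeholders here, and the real loop inside RLlib's sampler additionally handles episode bookkeeping and off-policy actions)::

    from ray.rllib.env.base_env import BaseEnv

    env = BaseEnv.to_base_env(my_env)  # wraps gym.Env, VectorEnv, MultiAgentEnv, ...
    while True:
        # Dicts keyed by environment id, then by agent id.
        obs, rewards, dones, infos, off_policy_actions = env.poll()
        actions = {}
        for env_id, agent_obs in obs.items():
            if dones[env_id]["__all__"]:
                env.try_reset(env_id)  # finished envs are reset rather than stepped
            else:
                actions[env_id] = {
                    agent_id: compute_action(ob)
                    for agent_id, ob in agent_obs.items()
                }
        env.send_actions(actions)
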
2 changes: 1 addition & 1 deletion doc/source/rllib-envs.svg
4 changes: 2 additions & 2 deletions python/ray/rllib/__init__.py
@@ -10,7 +10,7 @@

from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.external_env import ExternalEnv
@@ -47,7 +47,7 @@ def _register_all():
"TFPolicyGraph",
"PolicyEvaluator",
"SampleBatch",
"AsyncVectorEnv",
"BaseEnv",
"MultiAgentEnv",
"VectorEnv",
"ExternalEnv",
15 changes: 14 additions & 1 deletion python/ray/rllib/agents/agent.py
@@ -18,7 +18,7 @@
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
@@ -182,6 +182,7 @@
# yapf: enable


@DeveloperAPI
def with_common_config(extra_config):
"""Returns the given config dict merged with common agent confs."""

@@ -196,6 +197,7 @@ def with_base_config(base_config, extra_config):
return config


@PublicAPI
class Agent(Trainable):
"""All RLlib agents extend this base class.

@@ -214,6 +216,7 @@ class Agent(Trainable):
"custom_resources_per_worker"
]

@PublicAPI
def __init__(self, config=None, env=None, logger_creator=None):
"""Initialize an RLLib agent.

@@ -266,6 +269,7 @@ def default_resource_request(cls, config):
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

@override(Trainable)
@PublicAPI
def train(self):
"""Overrides super.train to synchronize global vars."""

@@ -344,11 +348,13 @@ def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path, "rb"))
self.__setstate__(extra_data)

@DeveloperAPI
def _init(self):
"""Subclasses should override this for custom initialization."""

raise NotImplementedError

@PublicAPI
def compute_action(self,
observation,
state=None,
@@ -404,6 +410,7 @@ def _default_config(self):

raise NotImplementedError

@PublicAPI
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
"""Return policy graph for the specified id, or None.

@@ -413,6 +420,7 @@ def get_policy(self, policy_id=DEFAULT_POLICY_ID):

return self.local_evaluator.get_policy(policy_id)

@PublicAPI
def get_weights(self, policies=None):
"""Return a dictionary of policy ids to weights.

@@ -422,6 +430,7 @@ def get_weights(self, policies=None):
"""
return self.local_evaluator.get_weights(policies)

@PublicAPI
def set_weights(self, weights):
"""Set policy weights by policy id.

@@ -430,6 +439,7 @@ def set_weights(self, weights):
"""
self.local_evaluator.set_weights(weights)

@DeveloperAPI
def make_local_evaluator(self, env_creator, policy_graph):
"""Convenience method to return configured local evaluator."""

@@ -444,6 +454,7 @@ def make_local_evaluator(self, env_creator, policy_graph):
config["local_evaluator_tf_session_args"]
}))

@DeveloperAPI
def make_remote_evaluators(self, env_creator, policy_graph, count):
"""Convenience method to return a number of remote evaluators."""

@@ -459,6 +470,7 @@ def make_remote_evaluators(self, env_creator, policy_graph, count):
self.config) for i in range(count)
]

@DeveloperAPI
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
"""Export policy model with given policy_id to local directory.

@@ -474,6 +486,7 @@ def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
"""
self.local_evaluator.export_policy_model(export_dir, policy_id)

@DeveloperAPI
def export_policy_checkpoint(self,
export_dir,
filename_prefix="model",
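
Taken together, the ``@PublicAPI`` methods above form the stable user-facing surface of ``Agent``, while the ``@DeveloperAPI`` ones target authors of custom algorithms. A usage sketch under the assumption that a concrete subclass such as ``PPOAgent`` is available (illustrative only, not taken from this diff):

    import ray
    from ray.rllib.agents.ppo import PPOAgent  # assumed concrete Agent subclass

    ray.init()

    # "env" may be a registered env name or a gym env id; config keys are merged
    # over the common defaults produced by with_common_config().
    agent = PPOAgent(env="CartPole-v0", config={"num_workers": 2})

    for _ in range(10):
        result = agent.train()            # @PublicAPI: one training iteration
        print(result["episode_reward_mean"])

    action = agent.compute_action([0.0, 0.0, 0.0, 0.0])   # @PublicAPI
    weights = agent.get_weights()         # @PublicAPI: {policy_id: weights}
    agent.set_weights(weights)            # @PublicAPI
    agent.export_policy_model("/tmp/rllib_export")        # @DeveloperAPI
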
6 changes: 3 additions & 3 deletions python/ray/rllib/env/__init__.py
@@ -1,11 +1,11 @@
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.serving_env import ServingEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.env_context import EnvContext

__all__ = [
"AsyncVectorEnv", "MultiAgentEnv", "ExternalEnv", "VectorEnv",
"ServingEnv", "EnvContext"
"BaseEnv", "MultiAgentEnv", "ExternalEnv", "VectorEnv", "ServingEnv",
"EnvContext"
]
python/ray/rllib/env/{async_vector_env.py → base_env.py} (renamed)
@@ -5,23 +5,24 @@
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI


class AsyncVectorEnv(object):
@PublicAPI
class BaseEnv(object):
"""The lowest-level env interface used by RLlib for sampling.

AsyncVectorEnv models multiple agents executing asynchronously in multiple
BaseEnv models multiple agents executing asynchronously in multiple
environments. A call to poll() returns observations from ready agents
keyed by their environment and agent ids, and actions for those agents
can be sent back via send_actions().

All other env types can be adapted to AsyncVectorEnv. RLlib handles these
All other env types can be adapted to BaseEnv. RLlib handles these
conversions internally in PolicyEvaluator, for example:

gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
rllib.MultiAgentEnv => rllib.AsyncVectorEnv
rllib.ExternalEnv => rllib.AsyncVectorEnv
gym.Env => rllib.VectorEnv => rllib.BaseEnv
rllib.MultiAgentEnv => rllib.BaseEnv
rllib.ExternalEnv => rllib.BaseEnv

Attributes:
action_space (gym.Space): Action space. This must be defined for
@@ -30,7 +31,7 @@ class AsyncVectorEnv(object):
for single-agent envs. Multi-agent envs can set this to None.

Examples:
>>> env = MyAsyncVectorEnv()
>>> env = MyBaseEnv()
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
>>> print(obs)
{
@@ -65,26 +66,27 @@ class AsyncVectorEnv(object):
"""

@staticmethod
def wrap_async(env, make_env=None, num_envs=1):
def to_base_env(env, make_env=None, num_envs=1):
"""Wraps any env type as needed to expose the async interface."""
if not isinstance(env, AsyncVectorEnv):
if not isinstance(env, BaseEnv):
if isinstance(env, MultiAgentEnv):
env = _MultiAgentEnvToAsync(
env = _MultiAgentEnvToBaseEnv(
make_env=make_env, existing_envs=[env], num_envs=num_envs)
elif isinstance(env, ExternalEnv):
if num_envs != 1:
raise ValueError(
"ExternalEnv does not currently support num_envs > 1.")
env = _ExternalEnvToAsync(env)
env = _ExternalEnvToBaseEnv(env)
elif isinstance(env, VectorEnv):
env = _VectorEnvToAsync(env)
env = _VectorEnvToBaseEnv(env)
else:
env = VectorEnv.wrap(
make_env=make_env, existing_envs=[env], num_envs=num_envs)
env = _VectorEnvToAsync(env)
assert isinstance(env, AsyncVectorEnv)
env = _VectorEnvToBaseEnv(env)
assert isinstance(env, BaseEnv)
return env

@PublicAPI
def poll(self):
"""Returns observations from ready agents.

@@ -107,6 +109,7 @@ def poll(self):
"""
raise NotImplementedError

@PublicAPI
def send_actions(self, action_dict):
"""Called to send actions back to running agents in this env.

@@ -118,6 +121,7 @@ def send_actions(self, action_dict):
"""
raise NotImplementedError

@PublicAPI
def try_reset(self, env_id):
"""Attempt to reset the env with the given id.

@@ -129,6 +133,7 @@ def try_reset(self, env_id):
"""
return None

@PublicAPI
def get_unwrapped(self):
"""Return a reference to the underlying gym envs, if any.

@@ -146,8 +151,8 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
return {k: {dummy_id: v} for (k, v) in env_id_to_values.items()}


class _ExternalEnvToAsync(AsyncVectorEnv):
"""Internal adapter of ExternalEnv to AsyncVectorEnv."""
class _ExternalEnvToBaseEnv(BaseEnv):
"""Internal adapter of ExternalEnv to BaseEnv."""

def __init__(self, external_env, preprocessor=None):
self.external_env = external_env
@@ -159,7 +164,7 @@ def __init__(self, external_env, preprocessor=None):
self.observation_space = external_env.observation_space
external_env.start()

@override(AsyncVectorEnv)
@override(BaseEnv)
def poll(self):
with self.external_env._results_avail_condition:
results = self._poll()
@@ -174,7 +179,7 @@ def poll(self):
"ExternalEnv was created with max_concurrent={}".format(limit))
return results

@override(AsyncVectorEnv)
@override(BaseEnv)
def send_actions(self, action_dict):
for eid, action in action_dict.items():
self.external_env._episodes[eid].action_queue.put(
@@ -204,8 +209,8 @@ def _poll(self):
_with_dummy_agent_id(off_policy_actions)


class _VectorEnvToAsync(AsyncVectorEnv):
"""Internal adapter of VectorEnv to AsyncVectorEnv.
class _VectorEnvToBaseEnv(BaseEnv):
"""Internal adapter of VectorEnv to BaseEnv.

We assume the caller will always send the full vector of actions in each
call to send_actions(), and that they call reset_at() on all completed
@@ -222,7 +227,7 @@ def __init__(self, vector_env):
self.cur_dones = [False for _ in range(self.num_envs)]
self.cur_infos = [None for _ in range(self.num_envs)]

@override(AsyncVectorEnv)
@override(BaseEnv)
def poll(self):
if self.new_obs is None:
self.new_obs = self.vector_env.vector_reset()
@@ -239,25 +244,25 @@ def poll(self):
_with_dummy_agent_id(dones, "__all__"), \
_with_dummy_agent_id(infos), {}

@override(AsyncVectorEnv)
@override(BaseEnv)
def send_actions(self, action_dict):
action_vector = [None] * self.num_envs
for i in range(self.num_envs):
action_vector[i] = action_dict[i][_DUMMY_AGENT_ID]
self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \
self.vector_env.vector_step(action_vector)

@override(AsyncVectorEnv)
@override(BaseEnv)
def try_reset(self, env_id):
return {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)}

@override(AsyncVectorEnv)
@override(BaseEnv)
def get_unwrapped(self):
return self.vector_env.get_unwrapped()


class _MultiAgentEnvToAsync(AsyncVectorEnv):
"""Internal adapter of MultiAgentEnv to AsyncVectorEnv.
class _MultiAgentEnvToBaseEnv(BaseEnv):
"""Internal adapter of MultiAgentEnv to BaseEnv.

This also supports vectorization if num_envs > 1.
"""
@@ -282,14 +287,14 @@ def __init__(self, make_env, existing_envs, num_envs):
assert isinstance(env, MultiAgentEnv)
self.env_states = [_MultiAgentEnvState(env) for env in self.envs]

@override(AsyncVectorEnv)
@override(BaseEnv)
def poll(self):
obs, rewards, dones, infos = {}, {}, {}, {}
for i, env_state in enumerate(self.env_states):
obs[i], rewards[i], dones[i], infos[i] = env_state.poll()
return obs, rewards, dones, infos, {}

@override(AsyncVectorEnv)
@override(BaseEnv)
def send_actions(self, action_dict):
for env_id, agent_dict in action_dict.items():
if env_id in self.dones:
@@ -311,15 +316,15 @@ def send_actions(self, action_dict):
self.dones.add(env_id)
self.env_states[env_id].observe(obs, rewards, dones, infos)

@override(AsyncVectorEnv)
@override(BaseEnv)
def try_reset(self, env_id):
obs = self.env_states[env_id].reset()
assert isinstance(obs, dict), "Not a multi-agent obs"
if obs is not None and env_id in self.dones:
self.dones.remove(env_id)
return obs

@override(AsyncVectorEnv)
@override(BaseEnv)
def get_unwrapped(self):
return [state.env for state in self.env_states]

3 changes: 3 additions & 0 deletions python/ray/rllib/env/env_context.py
@@ -2,7 +2,10 @@
from __future__ import division
from __future__ import print_function

from ray.rllib.utils.annotations import PublicAPI


@PublicAPI
class EnvContext(dict):
"""Wraps env configurations to include extra rllib metadata.

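As a rough sketch, an environment can treat the ``EnvContext`` passed to it as its ordinary ``env_config`` dict, with extra per-worker metadata exposed as attributes (the attribute name ``worker_index`` below is an assumption for illustration):

    import gym

    class MyEnv(gym.Env):
        # (observation_space, action_space, reset(), step() omitted for brevity)
        def __init__(self, env_config):
            # env_config is an EnvContext: normal dict access still works ...
            self.difficulty = env_config.get("difficulty", 1)
            # ... while rllib metadata is available as attributes (assumed name).
            self.seed = 1000 + getattr(env_config, "worker_index", 0)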