[RLlib]: Rename input_evaluation to off_policy_estimation_methods.
Rohan Potdar authored May 27, 2022
1 parent 0bc04f2 commit ab81c8e
Showing 20 changed files with 74 additions and 61 deletions.
2 changes: 1 addition & 1 deletion doc/source/rllib/rllib-env.rst
@@ -521,7 +521,7 @@ You can configure any Trainer to launch a policy server with the following confi
# Use the existing trainer process to run the server.
"num_workers": 0,
# Disable OPE, since the rollouts are coming from online clients.
"input_evaluation": [],
"off_policy_estimation_methods": [],
}
Clients can then connect in either *local* or *remote* inference mode. In local inference mode, copies of the policy are downloaded from the server and cached on the client for a configurable period of time. This allows actions to be computed by the client without requiring a network round trip each time. In remote inference mode, each computed action requires a network call to the server.
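For orientation (not part of this diff), a minimal client-side sketch of the local inference mode described above might look as follows; the server address, port, and env are illustrative placeholders:

import gym
from ray.rllib.env.policy_client import PolicyClient

# Local inference: the policy is cached on the client, so actions are
# computed locally rather than via one network round trip per step.
env = gym.make("CartPole-v0")
client = PolicyClient("http://localhost:9900", inference_mode="local")

obs = env.reset()
episode_id = client.start_episode(training_enabled=True)
done = False
while not done:
    action = client.get_action(episode_id, obs)
    obs, reward, done, info = env.step(action)
    client.log_returns(episode_id, reward)
client.end_episode(episode_id, obs)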
8 changes: 4 additions & 4 deletions doc/source/rllib/rllib-offline.rst
@@ -48,7 +48,7 @@ Then, we can tell DQN to train using these previously generated experiences with
--env=CartPole-v0 \
--config='{
"input": "/tmp/cartpole-out",
"input_evaluation": [],
"off_policy_estimation_methods": [],
"explore": false}'
.. _is:
@@ -62,7 +62,7 @@ Then, we can tell DQN to train using these previously generated experiences with
--env=CartPole-v0 \
--config='{
"input": "/tmp/cartpole-out",
"input_evaluation": ["is", "wis"],
"off_policy_estimation_methods": ["is", "wis"],
"exploration_config": {
"type": "SoftQ",
"temperature": 1.0,
@@ -90,7 +90,7 @@ This example plot shows the Q-value metric in addition to importance sampling (I
print(estimator.estimate(episode))
**Simulation-based estimation:** If true simulation is also possible (i.e., your env supports ``step()``), you can also set ``"input_evaluation": ["simulation"]`` to tell RLlib to run background simulations to estimate current policy performance. The output of these simulations will not be used for learning. Note that in all cases you still need to specify an environment object to define the action and observation spaces. However, you don't need to implement functions like reset() and step().
**Simulation-based estimation:** If true simulation is also possible (i.e., your env supports ``step()``), you can also set ``"off_policy_estimation_methods": ["simulation"]`` to tell RLlib to run background simulations to estimate current policy performance. The output of these simulations will not be used for learning. Note that in all cases you still need to specify an environment object to define the action and observation spaces. However, you don't need to implement functions like reset() and step().
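As a small illustrative sketch (the values are placeholders, not from this diff), the renamed key is used like this for simulation-based estimation:

config = {
    # An env is still required to define the spaces; step() must work for "simulation".
    "env": "CartPole-v0",
    "input": "/tmp/cartpole-out",
    "off_policy_estimation_methods": ["simulation"],
}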

Example: Converting external experiences to batch format
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -270,7 +270,7 @@ You can configure experience input for an agent using the following options:
# - Any subclass of OffPolicyEstimator, e.g.
# ray.rllib.offline.estimators.is::ImportanceSampling or your own custom
# subclass.
"input_evaluation": [
"off_policy_estimation_methods": [
ImportanceSampling,
WeightedImportanceSampling,
],
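A minimal sketch of the class-based form listed above (the input path is illustrative):

from ray.rllib.offline.estimators import (
    ImportanceSampling,
    WeightedImportanceSampling,
)

config = {
    "input": "/tmp/cartpole-out",
    "off_policy_estimation_methods": [
        ImportanceSampling,
        WeightedImportanceSampling,
    ],
}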
2 changes: 1 addition & 1 deletion doc/source/rllib/rllib-training.rst
@@ -574,7 +574,7 @@ The following is a list of the common algorithm hyper-parameters:
# - Any subclass of OffPolicyEstimator, e.g.
# ray.rllib.offline.estimators.is::ImportanceSampling or your own custom
# subclass.
"input_evaluation": [
"off_policy_estimation_methods": [
ImportanceSampling,
WeightedImportanceSampling,
],
@@ -11,7 +11,7 @@ marwil-halfcheetahbulletenv-v0:
input: ["~/halfcheetah_expert_sac.zip"]
actions_in_input_normalized: true
# Switch off input evaluation (data does not contain action probs).
input_evaluation: []
off_policy_estimation_methods: []

num_gpus: 1

2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -1116,7 +1116,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "DQN",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"input_evaluation\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"off_policy_estimation_methods\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
]
)

18 changes: 11 additions & 7 deletions rllib/agents/trainer.py
@@ -34,6 +34,7 @@
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.utils import _gym_env_creator
from ray.rllib.evaluation.episode import Episode
from ray.rllib.utils import force_list
from ray.rllib.evaluation.metrics import (
collect_episodes,
collect_metrics,
@@ -1890,14 +1891,17 @@ def validate_config(self, config: TrainerConfigDict) -> None:
)

# Offline RL settings.
if isinstance(config["input_evaluation"], tuple):
config["input_evaluation"] = list(config["input_evaluation"])
elif not isinstance(config["input_evaluation"], list):
raise ValueError(
"`input_evaluation` must be a list of strings, got {}!".format(
config["input_evaluation"]
)
input_evaluation = config.get("input_evaluation")
if input_evaluation is not None and input_evaluation is not DEPRECATED_VALUE:
deprecation_warning(
old="config.input_evaluation: {}".format(input_evaluation),
new="config.off_policy_estimation_methods={}".format(input_evaluation),
error=False,
)
config["off_policy_estimation_methods"] = input_evaluation
config["off_policy_estimation_methods"] = force_list(
config["off_policy_estimation_methods"]
)

# Check model config.
# If no preprocessing, propagate into model's config as well
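In effect, the block above keeps old configs working. A sketch of the equivalent logic (not the actual Trainer code) is:

from ray.rllib.utils import force_list

def resolve_ope_methods(config: dict) -> list:
    # Sketch only: mirrors the compatibility path in validate_config() above.
    input_evaluation = config.get("input_evaluation")
    if input_evaluation is not None:
        # The old key is still honored (with a deprecation warning in RLlib
        # itself) and its value is carried over to the new key.
        config["off_policy_estimation_methods"] = input_evaluation
    # force_list() normalizes a single entry or a tuple into a plain list.
    return force_list(config.get("off_policy_estimation_methods", []))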
27 changes: 16 additions & 11 deletions rllib/agents/trainer_config.py
@@ -15,12 +15,8 @@
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling,
)
from ray.rllib.utils import deep_update, merge_dicts
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.typing import (
EnvConfigDict,
EnvType,
@@ -170,10 +166,7 @@ def __init__(self, trainer_class=None):
self.input_ = "sampler"
self.input_config = {}
self.actions_in_input_normalized = False
self.input_evaluation = [
ImportanceSampling,
WeightedImportanceSampling,
]
self.off_policy_estimation_methods = []
self.postprocess_inputs = False
self.shuffle_buffer_size = 0
self.output = None
@@ -236,6 +229,7 @@ def __init__(self, trainer_class=None):
self.prioritized_replay_alpha = DEPRECATED_VALUE
self.prioritized_replay_beta = DEPRECATED_VALUE
self.prioritized_replay_eps = DEPRECATED_VALUE
self.input_evaluation = DEPRECATED_VALUE

def to_dict(self) -> TrainerConfigDict:
"""Converts all settings into a legacy config dict for backward compatibility.
@@ -862,6 +856,7 @@ def offline_data(
input_config=None,
actions_in_input_normalized=None,
input_evaluation=None,
off_policy_estimation_methods=None,
postprocess_inputs=None,
shuffle_buffer_size=None,
output=None,
@@ -906,7 +901,8 @@ are already normalized (between -1.0 and 1.0). This is usually the case
are already normalized (between -1.0 and 1.0). This is usually the case
when the offline file has been generated by another RLlib algorithm
(e.g. PPO or SAC), while "normalize_actions" was set to True.
input_evaluation: Specify how to evaluate the current policy.
input_evaluation: DEPRECATED: Use `off_policy_estimation_methods` instead!
off_policy_estimation_methods: Specify how to evaluate the current policy.
This only has an effect when reading offline experiences
("input" is not "sampler").
Available options:
@@ -945,7 +941,16 @@ def offline_data(
if actions_in_input_normalized is not None:
self.actions_in_input_normalized = actions_in_input_normalized
if input_evaluation is not None:
self.input_evaluation = input_evaluation
deprecation_warning(
old="offline_data(input_evaluation={})".format(input_evaluation),
new="offline_data(off_policy_estimation_methods={})".format(
input_evaluation
),
error=True,
)
self.off_policy_estimation_methods = input_evaluation
if off_policy_estimation_methods is not None:
self.off_policy_estimation_methods = off_policy_estimation_methods
if postprocess_inputs is not None:
self.postprocess_inputs = postprocess_inputs
if shuffle_buffer_size is not None:
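For reference, a hedged sketch of passing the renamed argument through the builder API. CQLConfig, the data path, and the `input_` keyword are assumptions for illustration; any TrainerConfig subclass exposing offline_data() should behave the same way:

from ray.rllib.algorithms.cql import CQLConfig

config = CQLConfig().offline_data(
    input_="/tmp/pendulum-out",  # illustrative offline data location
    # Replaces the now-deprecated `input_evaluation=...` argument.
    off_policy_estimation_methods=["is"],
)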
2 changes: 1 addition & 1 deletion rllib/algorithms/cql/cql.py
@@ -67,7 +67,7 @@ def __init__(self, trainer_class=None):

# Changes to Trainer's/SACConfig's default:
# .offline_data()
self.input_evaluation = []
self.off_policy_estimation_methods = []

# .reporting()
self.min_sample_timesteps_per_reporting = 0
2 changes: 1 addition & 1 deletion rllib/algorithms/cql/tests/test_cql.py
@@ -51,7 +51,7 @@ def test_cql_compilation(self):
# RLlib algorithm (e.g. PPO or SAC).
actions_in_input_normalized=False,
# Switch on off-policy evaluation.
input_evaluation=["is"],
off_policy_estimation_methods=["is"],
)
.training(
clip_actions=False,
2 changes: 1 addition & 1 deletion rllib/algorithms/marwil/bc.py
@@ -49,7 +49,7 @@ def __init__(self, trainer_class=None):
# not important for behavioral cloning.
self.postprocess_inputs = False
# No reward estimation.
self.input_evaluation = []
self.off_policy_estimation_methods = []
# __sphinx_doc_end__
# fmt: on

4 changes: 3 additions & 1 deletion rllib/algorithms/marwil/marwil.py
@@ -103,7 +103,9 @@ def __init__(self, trainer_class=None):
# the same line.
self.input_ = "sampler"
# Use importance sampling estimators for reward.
self.input_evaluation = [ImportanceSampling, WeightedImportanceSampling]
self.off_policy_estimation_methods = [
ImportanceSampling, WeightedImportanceSampling
]
self.postprocess_inputs = True
self.lr = 1e-4
self.train_batch_size = 2000
4 changes: 3 additions & 1 deletion rllib/algorithms/marwil/tests/test_marwil.py
@@ -115,7 +115,9 @@ def test_marwil_cont_actions_from_offline_file(self):
config["evaluation_config"] = {"input": "sampler"}
# Learn from offline data.
config["input"] = [data_file]
config["input_evaluation"] = [] # disable (data has no action-probs)
config[
"off_policy_estimation_methods"
] = [] # disable (data has no action-probs)
num_iterations = 3

# Test for all frameworks.
23 changes: 13 additions & 10 deletions rllib/evaluation/rollout_worker.py
@@ -43,7 +43,10 @@
from ray.rllib.utils import force_list, merge_dicts, check_env
from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
from ray.rllib.utils.deprecation import (
Deprecated,
deprecation_warning,
)
from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -239,7 +242,7 @@ def __init__(
input_creator: Callable[
[IOContext], InputReader
] = lambda ioctx: ioctx.default_sampler_input(),
input_evaluation: List[str] = frozenset([]),
off_policy_estimation_methods: List[str] = frozenset([]),
output_creator: Callable[
[IOContext], OutputWriter
] = lambda ioctx: NoopOutput(),
@@ -336,8 +339,8 @@
DefaultCallbacks for training/policy/rollout-worker callbacks.
input_creator: Function that returns an InputReader object for
loading previous generated experiences.
input_evaluation: How to evaluate the policy performance. Setting this only
makes sense when the input is reading offline data.
off_policy_estimation_methods: How to evaluate the policy performance.
Setting this only makes sense when the input is reading offline data.
Available options:
- "simulation" (str): Run the environment in the background, but use
this data for evaluation only and not for learning.
@@ -696,22 +699,22 @@ def wrap(env):
log_dir, policy_config, worker_index, self
)
self.reward_estimators: List[OffPolicyEstimator] = []
for method in input_evaluation:
for method in off_policy_estimation_methods:
if method == "is":
method = ImportanceSampling
deprecation_warning(
old="config.input_evaluation=[is]",
old="config.off_policy_estimation_methods=[is]",
new="from ray.rllib.offline.estimators import "
f"{method.__name__}; config.input_evaluation="
f"{method.__name__}; config.off_policy_estimation_methods="
f"[{method.__name__}]",
error=False,
)
elif method == "wis":
method = WeightedImportanceSampling
deprecation_warning(
old="config.input_evaluation=[wis]",
old="config.off_policy_estimation_methods=[wis]",
new="from ray.rllib.offline.estimators import "
f"{method.__name__}; config.input_evaluation="
f"{method.__name__}; config.off_policy_estimation_methods="
f"[{method.__name__}]",
error=False,
)
@@ -753,7 +756,7 @@ def wrap(env):
multiple_episodes_in_batch=pack,
normalize_actions=normalize_actions,
clip_actions=clip_actions,
blackhole_outputs="simulation" in input_evaluation,
blackhole_outputs="simulation" in off_policy_estimation_methods,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
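To make the warning above concrete, a small sketch contrasting the two spellings (config dicts are illustrative):

from ray.rllib.offline.estimators import ImportanceSampling, WeightedImportanceSampling

# Still accepted, but each string triggers a deprecation warning and is
# mapped to its estimator class by the worker:
legacy = {"off_policy_estimation_methods": ["is", "wis"]}

# Preferred spelling going forward:
preferred = {
    "off_policy_estimation_methods": [ImportanceSampling, WeightedImportanceSampling],
}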
6 changes: 3 additions & 3 deletions rllib/evaluation/worker_set.py
@@ -609,9 +609,9 @@ def valid_module(class_path):
)

if config["input"] == "sampler":
input_evaluation = []
off_policy_estimation_methods = []
else:
input_evaluation = config["input_evaluation"]
off_policy_estimation_methods = config["off_policy_estimation_methods"]

# Assert everything is correct in "multiagent" config dict (if given).
ma_policies = config["multiagent"]["policies"]
@@ -664,7 +664,7 @@ def valid_module(class_path):
log_level=config["log_level"],
callbacks=config["callbacks"],
input_creator=input_creator,
input_evaluation=input_evaluation,
off_policy_estimation_methods=off_policy_estimation_methods,
output_creator=output_creator,
remote_worker_envs=config["remote_worker_envs"],
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
3 changes: 0 additions & 3 deletions rllib/examples/parallel_evaluation_and_training.py
@@ -134,9 +134,6 @@ def on_train_result(self, *, trainer, result, **kwargs):
# Evaluate every other training iteration (together
# with every other call to Trainer.train()).
"evaluation_interval": args.evaluation_interval,
"evaluation_config": {
"input_evaluation": ["is"],
},
# Run for n episodes/timesteps (properly distribute load amongst
# all eval workers). The longer it takes to evaluate, the more sense
# it makes to use `evaluation_parallel_to_training=True`.
2 changes: 1 addition & 1 deletion rllib/examples/serving/cartpole_server.py
@@ -165,7 +165,7 @@ def _input(ioctx):
# Use n worker processes to listen on different ports.
"num_workers": args.num_workers,
# Disable OPE, since the rollouts are coming from online clients.
"input_evaluation": [],
"off_policy_estimation_methods": [],
# Create a "chatty" client/server or not.
"callbacks": MyCallbacks if args.callbacks_verbose else None,
# DL framework to use.
2 changes: 1 addition & 1 deletion rllib/examples/serving/unity3d_server.py
@@ -132,7 +132,7 @@ def _input(ioctx):
# Use n worker processes to listen on different ports.
"num_workers": args.num_workers,
# Disable OPE, since the rollouts are coming from online clients.
"input_evaluation": [],
"off_policy_estimation_methods": [],
# Other settings.
"train_batch_size": 256,
"rollout_fragment_length": 20,
8 changes: 4 additions & 4 deletions rllib/offline/off_policy_estimator.py
@@ -55,7 +55,7 @@ def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
if len(keys) > 1:
raise NotImplementedError(
"Off-policy estimation is not implemented for multi-agent. "
"You can set `input_evaluation: []` to resolve this."
"You can set `off_policy_estimation_methods: []` to resolve this."
)
policy = ioctx.worker.get_policy(keys[0])
return cls(policy, gamma)
@@ -134,8 +134,8 @@ def check_can_estimate_for(self, batch: SampleBatchType) -> None:

if isinstance(batch, MultiAgentBatch):
raise ValueError(
"IS-estimation is not implemented for multi-agent batches. "
"You can set `input_evaluation: []` to resolve this."
"off-policy estimation is not implemented for multi-agent batches. "
"You can set `off_policy_estimation_methods: []` to resolve this."
)

if "action_prob" not in batch:
@@ -144,7 +144,7 @@ def check_can_estimate_for(self, batch: SampleBatchType) -> None:
"include action probabilities (i.e., the policy is stochastic "
"and emits the 'action_prob' key). For DQN this means using "
"`exploration_config: {type: 'SoftQ'}`. You can also set "
"`input_evaluation: []` to disable estimation."
"`off_policy_estimation_methods: []` to disable estimation."
)

@DeveloperAPI
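A short sketch of a DQN config that satisfies the "action_prob" requirement mentioned in the error message above (paths and values are illustrative, mirroring the BUILD test earlier in this diff):

config = {
    "input": "tests/data/cartpole",
    "off_policy_estimation_methods": ["wis", "is"],
    # A stochastic behavior policy so that "action_prob" gets logged:
    "exploration_config": {"type": "SoftQ", "temperature": 1.0},
}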
