[RLlib]: Rename input_evaluation to off_policy_estimation_methods #25107

Merged: 6 commits, May 27, 2022
2 changes: 1 addition & 1 deletion doc/source/rllib/rllib-env.rst
@@ -521,7 +521,7 @@ You can configure any Trainer to launch a policy server with the following config:
# Use the existing trainer process to run the server.
"num_workers": 0,
# Disable OPE, since the rollouts are coming from online clients.
- "input_evaluation": [],
+ "off_policy_estimation_methods": [],
}

Clients can then connect in either *local* or *remote* inference mode. In local inference mode, copies of the policy are downloaded from the server and cached on the client for a configurable period of time. This allows actions to be computed by the client without requiring a network round trip each time. In remote inference mode, each computed action requires a network call to the server.
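To make the client/server flow above concrete, here is an illustrative sketch of a client connecting in local inference mode. It is not part of this PR; the server address/port and the CartPole environment on the client side are assumptions for illustration.

    # Sketch only (assumptions: a policy server reachable at localhost:9900,
    # an episodic env such as CartPole-v0 running on the client).
    import gym
    from ray.rllib.env.policy_client import PolicyClient

    env = gym.make("CartPole-v0")
    # inference_mode="local" caches the policy on the client; "remote" would
    # instead make a network call to the server for every computed action.
    client = PolicyClient("http://localhost:9900", inference_mode="local")

    obs = env.reset()
    episode_id = client.start_episode(training_enabled=True)
    done = False
    while not done:
        action = client.get_action(episode_id, obs)
        obs, reward, done, info = env.step(action)
        client.log_returns(episode_id, reward)
    client.end_episode(episode_id, obs)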
8 changes: 4 additions & 4 deletions doc/source/rllib/rllib-offline.rst
@@ -48,7 +48,7 @@ Then, we can tell DQN to train using these previously generated experiences with
--env=CartPole-v0 \
--config='{
"input": "/tmp/cartpole-out",
- "input_evaluation": [],
+ "off_policy_estimation_methods": [],
"explore": false}'

.. _is:
@@ -62,7 +62,7 @@ Then, we can tell DQN to train using these previously generated experiences with
--env=CartPole-v0 \
--config='{
"input": "/tmp/cartpole-out",
- "input_evaluation": ["is", "wis"],
+ "off_policy_estimation_methods": ["is", "wis"],
"exploration_config": {
"type": "SoftQ",
"temperature": 1.0,
@@ -90,7 +90,7 @@ This example plot shows the Q-value metric in addition to importance sampling (IS)
print(estimator.estimate(episode))


- **Simulation-based estimation:** If true simulation is also possible (i.e., your env supports ``step()``), you can also set ``"input_evaluation": ["simulation"]`` to tell RLlib to run background simulations to estimate current policy performance. The output of these simulations will not be used for learning. Note that in all cases you still need to specify an environment object to define the action and observation spaces. However, you don't need to implement functions like reset() and step().
+ **Simulation-based estimation:** If true simulation is also possible (i.e., your env supports ``step()``), you can also set ``"off_policy_estimation_methods": ["simulation"]`` to tell RLlib to run background simulations to estimate current policy performance. The output of these simulations will not be used for learning. Note that in all cases you still need to specify an environment object to define the action and observation spaces. However, you don't need to implement functions like reset() and step().
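As a small illustration of the paragraph above, the renamed key is set like any other config entry; the offline input path below is a placeholder, not part of this PR.

    # Sketch only: learn from offline data, run background simulations for OPE.
    config = {
        "env": "CartPole-v0",          # env still needed to define the spaces
        "input": "/tmp/cartpole-out",  # placeholder path to previously collected experiences
        # These rollouts are used for evaluation only, never for learning.
        "off_policy_estimation_methods": ["simulation"],
    }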

Example: Converting external experiences to batch format
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -270,7 +270,7 @@ You can configure experience input for an agent using the following options:
# - Any subclass of OffPolicyEstimator, e.g.
# ray.rllib.offline.estimators.is::ImportanceSampling or your own custom
# subclass.
- "input_evaluation": [
+ "off_policy_estimation_methods": [
ImportanceSampling,
WeightedImportanceSampling,
],
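The setting also accepts estimator classes directly, per the comment block above. A hedged sketch follows; the import path is the one given by the deprecation messages later in this PR, and the input path is a placeholder.

    # Sketch: pass OffPolicyEstimator subclasses instead of the legacy "is"/"wis" strings.
    from ray.rllib.offline.estimators import (
        ImportanceSampling,
        WeightedImportanceSampling,
    )

    config = {
        "input": "/tmp/cartpole-out",  # placeholder path
        "off_policy_estimation_methods": [
            ImportanceSampling,
            WeightedImportanceSampling,
        ],
    }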
2 changes: 1 addition & 1 deletion doc/source/rllib/rllib-training.rst
@@ -574,7 +574,7 @@ The following is a list of the common algorithm hyper-parameters:
# - Any subclass of OffPolicyEstimator, e.g.
# ray.rllib.offline.estimators.is::ImportanceSampling or your own custom
# subclass.
- "input_evaluation": [
+ "off_policy_estimation_methods": [
ImportanceSampling,
WeightedImportanceSampling,
],
@@ -11,7 +11,7 @@ marwil-halfcheetahbulletenv-v0:
input: ["~/halfcheetah_expert_sac.zip"]
actions_in_input_normalized: true
# Switch off input evaluation (data does not contain action probs).
- input_evaluation: []
+ off_policy_estimation_methods: []

num_gpus: 1

2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -1116,7 +1116,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "DQN",
"--stop", "'{\"training_iteration\": 1}'",
- "--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"input_evaluation\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
+ "--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"off_policy_estimation_methods\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
]
)

18 changes: 11 additions & 7 deletions rllib/agents/trainer.py
@@ -34,6 +34,7 @@
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.utils import gym_env_creator
from ray.rllib.evaluation.episode import Episode
+ from ray.rllib.utils import force_list
from ray.rllib.evaluation.metrics import (
collect_episodes,
collect_metrics,
@@ -1875,14 +1876,17 @@ def validate_config(self, config: TrainerConfigDict) -> None:
)

# Offline RL settings.
- if isinstance(config["input_evaluation"], tuple):
- config["input_evaluation"] = list(config["input_evaluation"])
- elif not isinstance(config["input_evaluation"], list):
- raise ValueError(
- "`input_evaluation` must be a list of strings, got {}!".format(
- config["input_evaluation"]
- )
+ input_evaluation = config.get("input_evaluation")
+ if input_evaluation is not None and input_evaluation is not DEPRECATED_VALUE:
+ deprecation_warning(
+ old="config.input_evaluation: {}".format(input_evaluation),
+ new="config.off_policy_estimation_methods={}".format(input_evaluation),
+ error=False,
+ )
+ config["off_policy_estimation_methods"] = input_evaluation
+ config["off_policy_estimation_methods"] = force_list(
+ config["off_policy_estimation_methods"]
)

Contributor: Awesome!

# Check model config.
# If no preprocessing, propagate into model's config as well
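To summarize the back-compat shim added above, here is a small, self-contained sketch. It deliberately avoids importing Ray; the inline list coercion mirrors what `force_list` does, and the function name is hypothetical.

    # Illustrative only: a legacy `input_evaluation` entry is copied to the new
    # key with a deprecation warning, and the value is normalized into a list.
    import warnings

    def migrate_ope_key(config: dict) -> dict:
        legacy = config.pop("input_evaluation", None)
        if legacy is not None:
            warnings.warn(
                "config.input_evaluation is deprecated; "
                "use config.off_policy_estimation_methods instead.",
                DeprecationWarning,
            )
            config["off_policy_estimation_methods"] = legacy
        methods = config.get("off_policy_estimation_methods", [])
        if not isinstance(methods, (list, tuple)):
            methods = [methods]  # force_list-style: wrap single values
        config["off_policy_estimation_methods"] = list(methods)
        return config

    print(migrate_ope_key({"input_evaluation": ("is", "wis")}))
    # {'off_policy_estimation_methods': ['is', 'wis']}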
27 changes: 16 additions & 11 deletions rllib/agents/trainer_config.py
@@ -15,12 +15,8 @@
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
from ray.rllib.models import MODEL_DEFAULTS
- from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
- from ray.rllib.offline.estimators.weighted_importance_sampling import (
- WeightedImportanceSampling,
- )
from ray.rllib.utils import deep_update, merge_dicts
- from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+ from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.typing import (
EnvConfigDict,
EnvType,
@@ -170,10 +166,7 @@ def __init__(self, trainer_class=None):
self.input_ = "sampler"
self.input_config = {}
self.actions_in_input_normalized = False
- self.input_evaluation = [
- ImportanceSampling,
- WeightedImportanceSampling,
- ]
+ self.off_policy_estimation_methods = []

Contributor: Any reason why we want to change this default? I'm worried that some users may rely on this being in their results dict and all of a sudden wonder where this data went and how to switch it back on.

Contributor: Got it. It's not needed here as we do this on the eval worker track. Please ignore my comment above.

Author (@ghost, May 24, 2022): Actually, I'm not sure. Like, we can keep it for backwards compatibility, but if the user doesn't explicitly ask for it, should it be enabled? We can discuss this further in the offline eval worker PR.

self.postprocess_inputs = False
self.shuffle_buffer_size = 0
self.output = None
@@ -236,6 +229,7 @@ def __init__(self, trainer_class=None):
self.prioritized_replay_alpha = DEPRECATED_VALUE
self.prioritized_replay_beta = DEPRECATED_VALUE
self.prioritized_replay_eps = DEPRECATED_VALUE
+ self.input_evaluation = DEPRECATED_VALUE

def to_dict(self) -> TrainerConfigDict:
"""Converts all settings into a legacy config dict for backward compatibility.
@@ -862,6 +856,7 @@ def offline_data(
input_config=None,
actions_in_input_normalized=None,
input_evaluation=None,
+ off_policy_estimation_methods=None,
postprocess_inputs=None,
shuffle_buffer_size=None,
output=None,
@@ -906,7 +901,8 @@ def offline_data(
are already normalized (between -1.0 and 1.0). This is usually the case
when the offline file has been generated by another RLlib algorithm
(e.g. PPO or SAC), while "normalize_actions" was set to True.
- input_evaluation: Specify how to evaluate the current policy.
+ input_evaluation: DEPRECATED: Use `off_policy_estimation_methods` instead!
+ off_policy_estimation_methods: Specify how to evaluate the current policy.
This only has an effect when reading offline experiences
("input" is not "sampler").
Available options:
@@ -945,7 +941,16 @@ def offline_data(
if actions_in_input_normalized is not None:
self.actions_in_input_normalized = actions_in_input_normalized
if input_evaluation is not None:
- self.input_evaluation = input_evaluation
+ deprecation_warning(
+ old="offline_data(input_evaluation={})".format(input_evaluation),
+ new="offline_data(off_policy_estimation_methods={})".format(
+ input_evaluation
+ ),
+ error=True,
+ )
+ self.off_policy_estimation_methods = input_evaluation
+ if off_policy_estimation_methods is not None:
+ self.off_policy_estimation_methods = off_policy_estimation_methods

Contributor: Very nice!

if postprocess_inputs is not None:
self.postprocess_inputs = postprocess_inputs
if shuffle_buffer_size is not None:
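A hedged usage sketch of the updated builder method above. The `TrainerConfig` import path and the `input_` keyword are assumptions based on this file for the Ray version the PR targets and may differ in later releases.

    from ray.rllib.agents.trainer_config import TrainerConfig  # path assumed for this Ray version

    config = TrainerConfig().offline_data(
        input_="/tmp/cartpole-out",                   # placeholder path, keyword assumed
        off_policy_estimation_methods=["is", "wis"],  # new keyword from this PR
    )

    # The legacy keyword now fails hard (deprecation_warning with error=True):
    # TrainerConfig().offline_data(input_evaluation=["is"])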
2 changes: 1 addition & 1 deletion rllib/algorithms/cql/cql.py
@@ -67,7 +67,7 @@ def __init__(self, trainer_class=None):

# Changes to Trainer's/SACConfig's default:
# .offline_data()
- self.input_evaluation = []
+ self.off_policy_estimation_methods = []

# .reporting()
self.min_sample_timesteps_per_reporting = 0
2 changes: 1 addition & 1 deletion rllib/algorithms/cql/tests/test_cql.py
@@ -51,7 +51,7 @@ def test_cql_compilation(self):
# RLlib algorithm (e.g. PPO or SAC).
actions_in_input_normalized=False,
# Switch on off-policy evaluation.
- input_evaluation=["is"],
+ off_policy_estimation_methods=["is"],
)
.training(
clip_actions=False,
2 changes: 1 addition & 1 deletion rllib/algorithms/marwil/bc.py
@@ -49,7 +49,7 @@ def __init__(self, trainer_class=None):
# not important for behavioral cloning.
self.postprocess_inputs = False
# No reward estimation.
- self.input_evaluation = []
+ self.off_policy_estimation_methods = []
# __sphinx_doc_end__
# fmt: on

4 changes: 3 additions & 1 deletion rllib/algorithms/marwil/marwil.py
@@ -103,7 +103,9 @@
# the same line.
self.input_ = "sampler"
# Use importance sampling estimators for reward.
- self.input_evaluation = [ImportanceSampling, WeightedImportanceSampling]
+ self.off_policy_estimation_methods = [
+ ImportanceSampling, WeightedImportanceSampling
+ ]
self.postprocess_inputs = True
self.lr = 1e-4
self.train_batch_size = 2000
4 changes: 3 additions & 1 deletion rllib/algorithms/marwil/tests/test_marwil.py
@@ -115,7 +115,9 @@ def test_marwil_cont_actions_from_offline_file(self):
config["evaluation_config"] = {"input": "sampler"}
# Learn from offline data.
config["input"] = [data_file]
- config["input_evaluation"] = [] # disable (data has no action-probs)
+ config[
+ "off_policy_estimation_methods"
+ ] = [] # disable (data has no action-probs)
num_iterations = 3

# Test for all frameworks.
23 changes: 13 additions & 10 deletions rllib/evaluation/rollout_worker.py
@@ -42,7 +42,10 @@
from ray.rllib.utils import force_list, merge_dicts, check_env
from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
- from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
+ from ray.rllib.utils.deprecation import (
+ Deprecated,
+ deprecation_warning,
+ )
from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -238,7 +241,7 @@ def __init__(
input_creator: Callable[
[IOContext], InputReader
] = lambda ioctx: ioctx.default_sampler_input(),
- input_evaluation: List[str] = frozenset([]),
+ off_policy_estimation_methods: List[str] = frozenset([]),
output_creator: Callable[
[IOContext], OutputWriter
] = lambda ioctx: NoopOutput(),
@@ -335,8 +338,8 @@ def __init__(
DefaultCallbacks for training/policy/rollout-worker callbacks.
input_creator: Function that returns an InputReader object for
loading previous generated experiences.
- input_evaluation: How to evaluate the policy performance. Setting this only
- makes sense when the input is reading offline data.
+ off_policy_estimation_methods: How to evaluate the policy performance.
+ Setting this only makes sense when the input is reading offline data.
Available options:
- "simulation" (str): Run the environment in the background, but use
this data for evaluation only and not for learning.
@@ -695,22 +698,22 @@ def wrap(env):
log_dir, policy_config, worker_index, self
)
self.reward_estimators: List[OffPolicyEstimator] = []
- for method in input_evaluation:
+ for method in off_policy_estimation_methods:
if method == "is":
method = ImportanceSampling
deprecation_warning(
old="config.input_evaluation=[is]",
old="config.off_policy_estimation_methods=[is]",
new="from ray.rllib.offline.estimators import "
f"{method.__name__}; config.input_evaluation="
f"{method.__name__}; config.off_policy_estimation_methods="
f"[{method.__name__}]",
error=False,
)
elif method == "wis":
method = WeightedImportanceSampling
deprecation_warning(
old="config.input_evaluation=[wis]",
old="config.off_policy_estimation_methods=[wis]",
new="from ray.rllib.offline.estimators import "
f"{method.__name__}; config.input_evaluation="
f"{method.__name__}; config.off_policy_estimation_methods="
f"[{method.__name__}]",
error=False,
)
@@ -752,7 +755,7 @@ def wrap(env):
multiple_episodes_in_batch=pack,
normalize_actions=normalize_actions,
clip_actions=clip_actions,
- blackhole_outputs="simulation" in input_evaluation,
+ blackhole_outputs="simulation" in off_policy_estimation_methods,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn,
6 changes: 3 additions & 3 deletions rllib/evaluation/worker_set.py
@@ -603,9 +603,9 @@ def valid_module(class_path):
)

if config["input"] == "sampler":
- input_evaluation = []
+ off_policy_estimation_methods = []
else:
- input_evaluation = config["input_evaluation"]
+ off_policy_estimation_methods = config["off_policy_estimation_methods"]

# Assert everything is correct in "multiagent" config dict (if given).
ma_policies = config["multiagent"]["policies"]
@@ -658,7 +658,7 @@ def valid_module(class_path):
log_level=config["log_level"],
callbacks=config["callbacks"],
input_creator=input_creator,
- input_evaluation=input_evaluation,
+ off_policy_estimation_methods=off_policy_estimation_methods,
output_creator=output_creator,
remote_worker_envs=config["remote_worker_envs"],
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
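The change above keeps the existing rule that OPE only applies to offline input. A tiny illustrative sketch of that rule follows (plain Python, not RLlib code; the helper name is hypothetical).

    # Illustrative only: with the default "sampler" input there is nothing to
    # estimate off-policy, so no estimators are passed to the rollout workers.
    def resolve_ope_methods(config: dict) -> list:
        if config.get("input", "sampler") == "sampler":
            return []
        return list(config.get("off_policy_estimation_methods", []))

    assert resolve_ope_methods({"input": "sampler"}) == []
    assert resolve_ope_methods(
        {"input": "/tmp/cartpole-out", "off_policy_estimation_methods": ["is"]}
    ) == ["is"]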
3 changes: 0 additions & 3 deletions rllib/examples/parallel_evaluation_and_training.py
@@ -134,9 +134,6 @@ def on_train_result(self, *, trainer, result, **kwargs):
# Evaluate every other training iteration (together
# with every other call to Trainer.train()).
"evaluation_interval": args.evaluation_interval,
- "evaluation_config": {
- "input_evaluation": ["is"],
- },

Contributor: Any reason you removed this here?

Author: Commenting for future reference, but that line of code wasn't actually doing anything, since eval workers don't read from offline data. Will be fixed in another PR soon.

Contributor: Yes, makes sense. Sampler -> no evaluation. Cool, I think that also answers the question on the default value of off_policy_estimation_methods being an empty list (I think we can leave this, as by default (input=sampler) no results are generated anyway).

# Run for n episodes/timesteps (properly distribute load amongst
# all eval workers). The longer it takes to evaluate, the more sense
# it makes to use `evaluation_parallel_to_training=True`.
2 changes: 1 addition & 1 deletion rllib/examples/serving/cartpole_server.py
@@ -165,7 +165,7 @@ def _input(ioctx):
# Use n worker processes to listen on different ports.
"num_workers": args.num_workers,
# Disable OPE, since the rollouts are coming from online clients.
- "input_evaluation": [],
+ "off_policy_estimation_methods": [],
# Create a "chatty" client/server or not.
"callbacks": MyCallbacks if args.callbacks_verbose else None,
# DL framework to use.
2 changes: 1 addition & 1 deletion rllib/examples/serving/unity3d_server.py
@@ -132,7 +132,7 @@ def _input(ioctx):
# Use n worker processes to listen on different ports.
"num_workers": args.num_workers,
# Disable OPE, since the rollouts are coming from online clients.
- "input_evaluation": [],
+ "off_policy_estimation_methods": [],
# Other settings.
"train_batch_size": 256,
"rollout_fragment_length": 20,
8 changes: 4 additions & 4 deletions rllib/offline/off_policy_estimator.py
@@ -53,7 +53,7 @@ def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
if len(keys) > 1:
raise NotImplementedError(
"Off-policy estimation is not implemented for multi-agent. "
- "You can set `input_evaluation: []` to resolve this."
+ "You can set `off_policy_estimation_methods: []` to resolve this."
)
policy = ioctx.worker.get_policy(keys[0])
return cls(policy, gamma)
@@ -132,8 +132,8 @@ def check_can_estimate_for(self, batch: SampleBatchType) -> None:

if isinstance(batch, MultiAgentBatch):
raise ValueError(
- "IS-estimation is not implemented for multi-agent batches. "
- "You can set `input_evaluation: []` to resolve this."
+ "off-policy estimation is not implemented for multi-agent batches. "
+ "You can set `off_policy_estimation_methods: []` to resolve this."
)

if "action_prob" not in batch:
@@ -142,7 +142,7 @@ def check_can_estimate_for(self, batch: SampleBatchType) -> None:
"include action probabilities (i.e., the policy is stochastic "
"and emits the 'action_prob' key). For DQN this means using "
"`exploration_config: {type: 'SoftQ'}`. You can also set "
- "`input_evaluation: []` to disable estimation."
+ "`off_policy_estimation_methods: []` to disable estimation."
)

@DeveloperAPI
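Tying the error messages above together, here is a hedged end-to-end config sketch, assembled from the BUILD test earlier in this PR rather than taken from any single file: OPE needs action probabilities in the batches, which for DQN means enabling stochastic SoftQ exploration.

    # Sketch only: DQN learning from offline data with IS/WIS estimation enabled.
    config = {
        "framework": "tf",
        "input": "tests/data/cartpole",           # path taken from the BUILD test above
        "off_policy_estimation_methods": ["wis", "is"],
        "exploration_config": {"type": "SoftQ"},  # makes the policy emit "action_prob"
        "replay_buffer_config": {"learning_starts": 0},
    }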