Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib]: Add Off-Policy Estimation docs #26809

Merged
merged 33 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
5ecac3d
wip
Rohan138 Jul 19, 2022
b715da1
wip
Rohan138 Jul 19, 2022
b555005
wip
Rohan138 Jul 19, 2022
c40fc8f
Merge branch 'master' of https://github.com/ray-project/ray into ope-…
Rohan138 Jul 20, 2022
a7c752e
wip
Rohan138 Jul 20, 2022
3bbe15c
wip
Rohan138 Jul 20, 2022
74030d2
wip
Rohan138 Jul 20, 2022
1631689
wip
Rohan138 Jul 20, 2022
136aaaf
Merge branch 'master' of https://github.com/ray-project/ray into ope-…
Rohan138 Jul 20, 2022
e175be3
wip
Rohan138 Jul 20, 2022
e0b90ef
wip
Rohan138 Jul 20, 2022
e4b2216
wip
Rohan138 Jul 20, 2022
6215b1a
wip
Rohan138 Jul 20, 2022
42fbd55
wip
Rohan138 Jul 20, 2022
d565123
wip
Rohan138 Jul 20, 2022
c3b29f8
wip
Rohan138 Jul 20, 2022
8abb341
wip
Rohan138 Jul 20, 2022
cbfb3bd
wip
Rohan138 Jul 20, 2022
a6c1e0b
wip
Rohan138 Jul 20, 2022
086db2f
Merge branch 'master' into ope-docs
Rohan138 Jul 21, 2022
6dbad71
Merge branch 'master' of https://github.com/ray-project/ray into ope-…
Rohan138 Jul 21, 2022
7142b46
updated the ope docs
kouroshHakha Jul 21, 2022
ec4a172
Merge branch 'ope-docs' of github.com:Rohan138/ray into ope-docs
kouroshHakha Jul 21, 2022
c261dbd
Minor typos and nits
Rohan138 Jul 22, 2022
fb70f92
Fix RLlib
Rohan138 Jul 22, 2022
64d528a
Add deprecation notice to MARWIL
Rohan138 Jul 22, 2022
4906f99
Fix MARWIL
Rohan138 Jul 22, 2022
14d73bd
Simplify fix
Rohan138 Jul 22, 2022
3e2ea18
Simplify fix
Rohan138 Jul 22, 2022
71038b2
Simplify fix
Rohan138 Jul 22, 2022
088b7b7
lint
Rohan138 Jul 22, 2022
309463d
Minor fix
Rohan138 Jul 23, 2022
f3f4b42
Minor fix
Rohan138 Jul 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed doc/source/rllib/images/offline-q.png
Binary file not shown.
Binary file added doc/source/rllib/images/rllib-offline.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
159 changes: 116 additions & 43 deletions doc/source/rllib/rllib-offline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ RLlib's offline dataset APIs enable working with experiences read from offline s
RLlib represents trajectory sequences (i.e., ``(s, a, r, s', ...)`` tuples) with `SampleBatch <https://github.com/ray-project/ray/blob/master/rllib/policy/sample_batch.py>`__ objects. Using a batch format enables efficient encoding and compression of experiences. During online training, RLlib uses `policy evaluation <rllib-concepts.html#policy-evaluation>`__ actors to generate batches of experiences in parallel using the current policy. RLlib also uses this same batch format for reading and writing experiences to offline storage.

Example: Training on previously saved experiences
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-------------------------------------------------

.. note::

Expand All @@ -23,7 +23,7 @@ In this example, we will save batches of experiences generated during online tra

.. code-block:: bash

$ rllib train
$ rllib train \
--run=PG \
--env=CartPole-v0 \
--config='{"output": "/tmp/cartpole-out", "output_max_file_size": 5000000}' \
Expand All @@ -50,53 +50,130 @@ Then, we can tell DQN to train using these previously generated experiences with
"input": "/tmp/cartpole-out",
"explore": false}'

.. _is:
Off-Policy Estimation (OPE)
---------------------------

In practice, when we use offline data for training, it is usually not straightforward to evaluate the trained policies using a simulator similar to how it is done in online RL. For example, in recommeder systems it is often the case that rolling out the policy in real-world can jeopardize your business use-case if the trained policy is not performant enough (e.g. by causing churn on your customers). For these situations we can use `off-policy estimation <https://arxiv.org/abs/1911.06854>`__ methods which help avoid the risk of evaluating your policy in real-world. RLlib provides some basic APIs for off-policy evaluation with an option to use a simulator instead (if it is available).

With RLlib evaluation framework you can:

- Evaluate policies on a simulated environement, if available, using ``evaluation_config["input"] = "sampler"``. You can then monitor your policy's performance on tensorboard as it is getting trained (by using ``tensorboard --logdir=~/ray_results``).

- Use RLlib's off-policy estimation methods, which estimate the policy's performance on a separate offline dataset. To be able to use this feature, the evaluation dataset should contain ``action_prob`` key that represents the action probability distribution of the collected data so that we can do counterfactual evaluation.

RLlib supports the following off-policy estimators:

- `Importance Sampling (IS) <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/importance_sampling.py>`__
- `Weighted Importance Sampling (WIS) <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/weighted_importance_sampling.py>`__
- `Direct Method (DM) <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/direct_method.py>`__
- `Doubly Robust (DR) <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/doubly_robust.py>`__

**Off-policy estimation:** Since the input experiences are not from running simulations, RLlib cannot report the true policy performance during training. However, you can use ``tensorboard --logdir=~/ray_results`` to monitor training progress via other metrics such as estimated Q-value. Alternatively, `off-policy estimation <https://arxiv.org/pdf/1511.03722.pdf>`__ can be used, which requires both the source and target action probabilities to be available (i.e., the ``action_prob`` batch key). For DQN, this means enabling soft Q learning so that actions are sampled from a probability distribution:
IS and WIS compute the ratio between the action probabilities under the behavior policy (from the dataset) and the target policy (the policy under evaluation), and use this ratio to estimate the policy's return. More details on this can be found in their respective papers.

DM and DR train a Q-model to compute the estimated return. By default, RLlib uses `Fitted-Q Evaluation (FQE) <https://arxiv.org/abs/1911.06854>`__ to train the Q-model. See `fqe_torch_model.py <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/fqe_torch_model.py>`__ for more details.

.. note:: For a contextual bandit dataset, the ``dones`` key should always be set to ``True``. In this case, FQE reduces to fitting a reward model to the data.

RLlib's OPE estimators output six metrics:

- ``v_behavior``: The discounted sum over rewards in the offline episode, averaged over episodes in the batch.
- ``v_behavior_std``: The standard deviation corresponding to v_behavior.
- ``v_target``: The OPE's estimated discounted return for the target policy, averaged over episodes in the batch.
- ``v_target_std``: The standard deviation corresponding to v_target.
- ``v_gain``: ``v_target / max(v_behavior, 1e-8)``, averaged over episodes in the batch. ``v_gain > 1.0`` indicates that the policy is better than the policy that generated the behavior data.
- ``v_gain_std``: The standard deviation corresponding to v_gain.

As an example, we generate an evaluation dataset for off-policy estimation:

.. code-block:: bash

$ rllib train \
--run=DQN \
--run=PG \
--env=CartPole-v0 \
--config='{
"input": "/tmp/cartpole-out",
"off_policy_estimation_methods": {
"is": {
"type": "ray.rllib.offline.estimators.ImportanceSampling",
--config='{"output": "/tmp/cartpole-eval", "output_max_file_size": 5000000}' \
--stop='{"timesteps_total": 10000}'

.. hint:: You should use separate datasets for algorithm training and OPE, as shown here.

We can now train a DQN algorithm offline and evaluate it using OPE:

.. code-block:: python

from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.offline.estimators import (
ImportanceSampling,
WeightedImportanceSampling,
DirectMethod,
DoublyRobust,
)
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel

config = (
DQNConfig()
.environment(env="CartPole-v0")
.framework("torch")
.offline_data(input_="/tmp/cartpole-out")
.evaluation(
evaluation_interval=1,
evaluation_duration=10,
evaluation_num_workers=1,
evaluation_duration_unit="episodes",
evaluation_config={"input": "/tmp/cartpole-eval"},
off_policy_estimation_methods={
"is": {"type": ImportanceSampling},
"wis": {"type": WeightedImportanceSampling},
"dm_fqe": {
"type": DirectMethod,
"q_model_config": {"type": FQETorchModel, "tau": 0.05},
},
"dr_fqe": {
"type": DoublyRobust,
"q_model_config": {"type": FQETorchModel, "tau": 0.05},
},
"wis": {
"type": "ray.rllib.offline.estimators.WeightedImportanceSampling",
}
},
"exploration_config": {
"type": "SoftQ",
"temperature": 1.0,
}'
)
)

This example plot shows the Q-value metric in addition to importance sampling (IS) and weighted importance sampling (WIS) gain estimates (>1.0 means there is an estimated improvement over the original policy):
algo = config.build()
for _ in range(100):
algo.train()

.. image:: images/offline-q.png
.. image:: images/rllib-offline.png

**Estimator Python API:** For greater control over the evaluation process, you can create off-policy estimators in your Python code and call ``estimator.estimate(episode_batch)`` to perform counterfactual estimation as needed. The estimators take in a policy object and gamma value for the environment:
**Estimator Python API:** For greater control over the evaluation process, you can create off-policy estimators in your Python code and call ``estimator.train(batch)`` to perform any neccessary training and ``estimator.estimate(batch)`` to perform counterfactual estimation. The estimators take in an RLLib Policy object and gamma value for the environment, along with additional estimator-specific arguments (e.g. ``q_model_config`` for DM and DR). You can take a look at the example config parameters of the q_model_config `here <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/fqe_torch_model.py>`__. You can also write your own off-policy estimator by subclassing from the `OffPolicyEstimator <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/off_policy_estimator.py>`__ base class.

.. code-block:: python

algo = DQN(...)
... # train policy offline

from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator

estimator = WeightedImportanceSamplingEstimator(algo.get_policy(), gamma=0.99)
reader = JsonReader("/path/to/data")
for _ in range(1000):
from ray.rllib.offline.estimators import DoublyRobust
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel

estimator = DoublyRobust(
policy=algo.get_policy(),
gamma=0.99,
q_model_config={"type": FQETorchModel, "n_iters": 160},
)

# Train estimator's Q-model; only required for DM and DR estimators
reader = JsonReader("/tmp/cartpole-out")
for _ in range(100):
batch = reader.next()
for episode in batch.split_by_episode():
print(estimator.estimate(episode))
print(estimator.train(batch))
# {'loss': ...}

reader = JsonReader("/tmp/cartpole-eval")
# Compute off-policy estimates
for _ in range(100):
batch = reader.next()
print(estimator.estimate(batch))
# {'v_behavior': ..., 'v_target': ..., 'v_gain': ...,
# 'v_behavior_std': ..., 'v_target_std': ..., 'v_gain_std': ...}

Example: Converting external experiences to batch format
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--------------------------------------------------------

When the env does not support simulation (e.g., it is a web application), it is necessary to generate the ``*.json`` experience batch files outside of RLlib. This can be done by using the `JsonWriter <https://github.com/ray-project/ray/blob/master/rllib/offline/json_writer.py>`__ class to write out batches.
This `runnable example <https://github.com/ray-project/ray/blob/master/rllib/examples/saving_experiences.py>`__ shows how to generate and save experience batches for CartPole-v0 to disk:
Expand All @@ -107,7 +184,7 @@ This `runnable example <https://github.com/ray-project/ray/blob/master/rllib/exa
:end-before: __sphinx_doc_end__

On-policy algorithms and experience postprocessing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
----------------------------------------------------

RLlib assumes that input batches are of
`postprocessed experiences <https://github.com/ray-project/ray/blob/master/rllib/policy/policy.py#L434>`__.
Expand All @@ -121,7 +198,7 @@ However, for on-policy algorithms like PPO, you'll need to pass in the extra val
Note that for on-policy algorithms, you'll also have to throw away experiences generated by prior versions of the policy. This greatly reduces sample efficiency, which is typically undesirable for offline training, but can make sense for certain applications.

Mixing simulation and offline data
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-----------------------------------

RLlib supports multiplexing inputs from multiple input sources, including simulation. For example, in the following example we read 40% of our experiences from ``/tmp/cartpole-out``, 30% from ``hdfs:/archive/cartpole``, and the last 30% is produced via policy evaluation. Input sources are multiplexed using `np.random.choice <https://docs.scipy.org/doc/numpy-1.15.0/reference/generated/numpy.random.choice.html>`__:

Expand All @@ -139,12 +216,12 @@ RLlib supports multiplexing inputs from multiple input sources, including simula
"explore": false}'

Scaling I/O throughput
~~~~~~~~~~~~~~~~~~~~~~
-----------------------

Similar to scaling online training, you can scale offline I/O throughput by increasing the number of RLlib workers via the ``num_workers`` config. Each worker accesses offline storage independently in parallel, for linear scaling of I/O throughput. Within each read worker, files are chosen in random order for reads, but file contents are read sequentially.

Ray Dataset Integration
~~~~~~~~~~~~~~~~~~~~~~~
--------------------------

RLlib has experimental support for reading/writing training samples from/to large offline datasets using
`Ray Dataset <https://docs.ray.io/en/latest/data/dataset.html>`__.
Expand Down Expand Up @@ -189,15 +266,15 @@ To write sample data to JSON or Parquet files using Dataset, specify output and
}

Writing Environment Data
~~~~~~~~~~~~~~~~~~~~~~~~
--------------------------

To include environment data in the training sample datasets you can use the optional
``store_infos`` parameter that is part of the ``output_config`` dictionary. This parameter
ensures that the ``infos`` dictionary, as returned by the RL environment, is included in the output files.

Note 1: It is the responsibility of the user to ensure that the content of ``infos`` can be serialized
to file.
Note 2: This setting is only relevant for the TensorFlow based agents, for PyTorch agents the ``infos`` data is always stored.
.. note:: It is the responsibility of the user to ensure that the content of ``infos`` can be serialized to file.

.. note:: This setting is only relevant for the TensorFlow based agents, for PyTorch agents the ``infos`` data is always stored.

To write the ``infos`` data to JSON or Parquet files using Dataset, specify output and output_config keys like the following:

Expand Down Expand Up @@ -279,12 +356,8 @@ You can configure experience input for an agent using the following options:
# ray.rllib.offline.estimators.is::ImportanceSampling or your own custom
# subclass.
"off_policy_estimation_methods": {
"is": {
"type": ImportanceSampling,
},
"wis": {
"type": WeightedImportanceSampling,
}
"is": {"type": ImportanceSampling},
"wis": {"type": WeightedImportanceSampling}
},
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Note that postprocessing will be done using the *current*
Expand All @@ -303,7 +376,7 @@ The interface for a custom input reader is as follows:
:noindex:

Example Custom Input API
~~~~~~~~~~~~~~~~~~~~~~~~
-----------------------

You can create a custom input reader like the following:

Expand Down
2 changes: 1 addition & 1 deletion rllib/algorithms/marwil/marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(self, algo_class=None):
# the same line.
self.input_ = "sampler"
# Use importance sampling estimators for reward.
self.evaluation_config["off_policy_estimation_methods"] = {
self.off_policy_estimation_methods = {
"is": {"type": ImportanceSampling},
"wis": {"type": WeightedImportanceSampling},
}
Expand Down
8 changes: 5 additions & 3 deletions rllib/offline/estimators/direct_method.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Dict, Any
from typing import Dict, Any, Optional
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
Expand Down Expand Up @@ -35,7 +36,7 @@ def __init__(
self,
policy: Policy,
gamma: float,
q_model_config: Dict = None,
q_model_config: Optional[Dict] = None,
):
"""Initializes a Direct Method OPE Estimator.

Expand All @@ -55,7 +56,8 @@ def __init__(
), "DirectMethod estimator only works with torch!"
super().__init__(policy, gamma)

model_cls = q_model_config.pop("type")
q_model_config = q_model_config or {}
model_cls = q_model_config.pop("type", FQETorchModel)
self.model = model_cls(
policy=policy,
gamma=gamma,
Expand Down
8 changes: 5 additions & 3 deletions rllib/offline/estimators/doubly_robust.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Any
from typing import Dict, Any, Optional
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
Expand All @@ -9,6 +9,7 @@
from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict

from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel

torch, nn = try_import_torch()

Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(
self,
policy: Policy,
gamma: float,
q_model_config: Dict = None,
q_model_config: Optional[Dict] = None,
):
"""Initializes a Doubly Robust OPE Estimator.

Expand All @@ -63,7 +64,8 @@ def __init__(
"""

super().__init__(policy, gamma)
model_cls = q_model_config.pop("type")
q_model_config = q_model_config or {}
model_cls = q_model_config.pop("type", FQETorchModel)

self.model = model_cls(
policy=policy,
Expand Down
6 changes: 3 additions & 3 deletions rllib/offline/estimators/fqe_torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
@DeveloperAPI
class FQETorchModel:
"""Pytorch implementation of the Fitted Q-Evaluation (FQE) model from
https://arxiv.org/pdf/1911.06854.pdf
https://arxiv.org/abs/1911.06854
"""

def __init__(
Expand All @@ -44,9 +44,9 @@ def __init__(
"vf_share_layers": True,
},
n_iters: Number of gradient steps to run on batch, defaults to 1
lr: Learning rate for Q-model optimizer
lr: Learning rate for Adam optimizer
delta: Early stopping threshold if the mean loss < delta
clip_grad_norm: Clip gradients to this maximum value
clip_grad_norm: Clip loss gradients to this maximum value
minibatch_size: Minibatch size for training Q-function;
if None, train on the whole batch
tau: Polyak averaging factor for target Q-function
Expand Down