Skip to content

Commit

Permalink
[RLlib] Eps greedy ope (#28837)
Browse files Browse the repository at this point in the history
* 1. Introduced new abstraction: OfflineEvaluator that is the parent of OPE and feature importance
2. introduced estimate_multi_step vs. estimate_single_step

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* algorithm ope evaluation is now able to skip split_by_episode

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* lint

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* lint

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* fixed some unittests

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* added eps greedy exploration to ope methods

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* wip

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* lint

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* wip

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* wip

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* fixed dm and dr variance issues

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* lint

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* cleaned up the inheritance

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* lint

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* lint

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* fixed test

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* nit

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* fixed nits

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* fixed the typos

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* nit

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* wip

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

* wip

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
  • Loading branch information
kouroshHakha authored Sep 29, 2022
1 parent 6dbc116 commit c1e0d39
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 38 deletions.
16 changes: 10 additions & 6 deletions rllib/offline/estimators/direct_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,26 @@ def __init__(
self,
policy: Policy,
gamma: float,
epsilon_greedy: float = 0.0,
q_model_config: Optional[Dict] = None,
):
"""Initializes a Direct Method OPE Estimator.
Args:
policy: Policy to evaluate.
gamma: Discount factor of the environment.
epsilon_greedy: The probability by which we act acording to a fully random
policy during deployment. With 1-epsilon_greedy we act according the
target policy.
q_model_config: Arguments to specify the Q-model. Must specify
a `type` key pointing to the Q-model class.
This Q-model is trained in the train() method and is used
to compute the state-value estimates for the DirectMethod estimator.
It must implement `train` and `estimate_v`.
TODO (Rohan138): Unify this with RLModule API.
a `type` key pointing to the Q-model class.
This Q-model is trained in the train() method and is used
to compute the state-value estimates for the DirectMethod estimator.
It must implement `train` and `estimate_v`.
TODO (Rohan138): Unify this with RLModule API.
"""

super().__init__(policy, gamma)
super().__init__(policy, gamma, epsilon_greedy)

q_model_config = q_model_config or {}
model_cls = q_model_config.pop("type", FQETorchModel)
Expand Down
25 changes: 13 additions & 12 deletions rllib/offline/estimators/doubly_robust.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict

from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
Expand Down Expand Up @@ -46,23 +45,27 @@ def __init__(
self,
policy: Policy,
gamma: float,
epsilon_greedy: float = 0.0,
q_model_config: Optional[Dict] = None,
):
"""Initializes a Doubly Robust OPE Estimator.
Args:
policy: Policy to evaluate.
gamma: Discount factor of the environment.
epsilon_greedy: The probability by which we act acording to a fully random
policy during deployment. With 1-epsilon_greedy we act
according the target policy.
q_model_config: Arguments to specify the Q-model. Must specify
a `type` key pointing to the Q-model class.
This Q-model is trained in the train() method and is used
to compute the state-value and Q-value estimates
for the DoublyRobust estimator.
It must implement `train`, `estimate_q`, and `estimate_v`.
TODO (Rohan138): Unify this with RLModule API.
a `type` key pointing to the Q-model class.
This Q-model is trained in the train() method and is used
to compute the state-value and Q-value estimates
for the DoublyRobust estimator.
It must implement `train`, `estimate_q`, and `estimate_v`.
TODO (Rohan138): Unify this with RLModule API.
"""

super().__init__(policy, gamma)
super().__init__(policy, gamma, epsilon_greedy)
q_model_config = q_model_config or {}
model_cls = q_model_config.pop("type", FQETorchModel)

Expand All @@ -83,8 +86,7 @@ def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]:
estimates_per_epsiode = {}

rewards, old_prob = episode["rewards"], episode["action_prob"]
log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, episode)
new_prob = np.exp(convert_to_numpy(log_likelihoods))
new_prob = self.compute_action_probs(episode)

v_behavior = 0.0
v_target = 0.0
Expand Down Expand Up @@ -113,8 +115,7 @@ def estimate_on_single_step_samples(
estimates_per_epsiode = {}

rewards, old_prob = batch["rewards"], batch["action_prob"]
log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
new_prob = np.exp(convert_to_numpy(log_likelihoods))
new_prob = self.compute_action_probs(batch)

q_values = self.model.estimate_q(batch)
q_values = convert_to_numpy(q_values)
Expand Down
9 changes: 2 additions & 7 deletions rllib/offline/estimators/importance_sampling.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
from typing import Dict, List
import numpy as np


@DeveloperAPI
Expand All @@ -28,8 +25,7 @@ def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, float]:
estimates_per_epsiode = {}

rewards, old_prob = episode["rewards"], episode["action_prob"]
log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, episode)
new_prob = np.exp(convert_to_numpy(log_likelihoods))
new_prob = self.compute_action_probs(episode)

# calculate importance ratios
p = []
Expand Down Expand Up @@ -59,8 +55,7 @@ def estimate_on_single_step_samples(
estimates_per_epsiode = {}

rewards, old_prob = batch["rewards"], batch["action_prob"]
log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
new_prob = np.exp(convert_to_numpy(log_likelihoods))
new_prob = self.compute_action_probs(batch)

weights = new_prob / old_prob
v_behavior = rewards
Expand Down
30 changes: 29 additions & 1 deletion rllib/offline/estimators/off_policy_estimator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gym
import numpy as np
import tree
from typing import Dict, Any, List
Expand All @@ -12,6 +13,7 @@
from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
from ray.rllib.utils.annotations import (
DeveloperAPI,
ExperimentalAPI,
OverrideToImplementCustomLogic,
)
from ray.rllib.utils.deprecation import Deprecated
Expand All @@ -27,15 +29,25 @@ class OffPolicyEstimator(OfflineEvaluator):
"""Interface for an off policy estimator for counterfactual evaluation."""

@DeveloperAPI
def __init__(self, policy: Policy, gamma: float = 0.0):
def __init__(
self,
policy: Policy,
gamma: float = 0.0,
epsilon_greedy: float = 0.0,
):
"""Initializes an OffPolicyEstimator instance.
Args:
policy: Policy to evaluate.
gamma: Discount factor of the environment.
epsilon_greedy: The probability by which we act acording to a fully random
policy during deployment. With 1-epsilon_greedy we act according the target
policy.
# TODO (kourosh): convert the input parameters to a config dict.
"""
self.policy = policy
self.gamma = gamma
self.epsilon_greedy = epsilon_greedy

@DeveloperAPI
def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]:
Expand Down Expand Up @@ -228,6 +240,22 @@ def check_action_prob_in_batch(self, batch: SampleBatchType) -> None:
"`off_policy_estimation_methods: {}` to disable estimation."
)

@ExperimentalAPI
def compute_action_probs(self, batch: SampleBatch):
log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
new_prob = np.exp(convert_to_numpy(log_likelihoods))

if self.epsilon_greedy > 0.0:
if not isinstance(self.policy.action_space, gym.spaces.Discrete):
raise ValueError(
"Evaluation with epsilon-greedy exploration is only supported "
"with discrete action spaces."
)
eps = self.epsilon_greedy
new_prob = new_prob * (1 - eps) + eps / self.policy.action_space.n

return new_prob

@DeveloperAPI
def train(self, batch: SampleBatchType) -> Dict[str, Any]:
"""Train a model for Off-Policy Estimation.
Expand Down
8 changes: 4 additions & 4 deletions rllib/offline/estimators/tests/test_ope.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def setUpClass(cls):
evaluation_num_workers=1,
evaluation_duration_unit="episodes",
off_policy_estimation_methods={
"is": {"type": ImportanceSampling},
"wis": {"type": WeightedImportanceSampling},
"dm_fqe": {"type": DirectMethod},
"dr_fqe": {"type": DoublyRobust},
"is": {"type": ImportanceSampling, "epsilon_greedy": 0.1},
"wis": {"type": WeightedImportanceSampling, "epsilon_greedy": 0.1},
"dm_fqe": {"type": DirectMethod, "epsilon_greedy": 0.1},
"dr_fqe": {"type": DoublyRobust, "epsilon_greedy": 0.1},
},
)
)
Expand Down
12 changes: 4 additions & 8 deletions rllib/offline/estimators/weighted_importance_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@

from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.numpy import convert_to_numpy


@DeveloperAPI
Expand All @@ -29,8 +27,8 @@ class WeightedImportanceSampling(OffPolicyEstimator):
For more information refer to https://arxiv.org/pdf/1911.06854.pdf"""

@override(OffPolicyEstimator)
def __init__(self, policy: Policy, gamma: float):
super().__init__(policy, gamma)
def __init__(self, policy: Policy, gamma: float, epsilon_greedy: float = 0.0):
super().__init__(policy, gamma, epsilon_greedy)
# map from time to cummulative propensity values
self.cummulative_ips_values = []
# map from time to number of episodes that reached this time
Expand Down Expand Up @@ -70,8 +68,7 @@ def estimate_on_single_step_samples(
) -> Dict[str, List[float]]:
estimates_per_epsiode = {}
rewards, old_prob = batch["rewards"], batch["action_prob"]
log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
new_prob = np.exp(convert_to_numpy(log_likelihoods))
new_prob = self.compute_action_probs(batch)

weights = new_prob / old_prob
v_behavior = rewards
Expand All @@ -95,8 +92,7 @@ def on_before_split_batch_by_episode(
@override(OffPolicyEstimator)
def peek_on_single_episode(self, episode: SampleBatch) -> None:
old_prob = episode["action_prob"]
log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, episode)
new_prob = np.exp(convert_to_numpy(log_likelihoods))
new_prob = self.compute_action_probs(episode)

# calculate importance ratios
episode_p = []
Expand Down

0 comments on commit c1e0d39

Please sign in to comment.