diff --git a/rllib/offline/estimators/direct_method.py b/rllib/offline/estimators/direct_method.py
index 8c0def3e48b3..c26f01aa71bc 100644
--- a/rllib/offline/estimators/direct_method.py
+++ b/rllib/offline/estimators/direct_method.py
@@ -34,6 +34,7 @@ def __init__(
         self,
         policy: Policy,
         gamma: float,
+        epsilon_greedy: float = 0.0,
         q_model_config: Optional[Dict] = None,
     ):
         """Initializes a Direct Method OPE Estimator.
@@ -41,15 +42,18 @@ def __init__(
         Args:
             policy: Policy to evaluate.
             gamma: Discount factor of the environment.
+            epsilon_greedy: The probability with which we act according to a fully
+                random policy during deployment. With probability 1-epsilon_greedy
+                we act according to the target policy.
             q_model_config: Arguments to specify the Q-model. Must specify
-            a `type` key pointing to the Q-model class.
-            This Q-model is trained in the train() method and is used
-            to compute the state-value estimates for the DirectMethod estimator.
-            It must implement `train` and `estimate_v`.
-            TODO (Rohan138): Unify this with RLModule API.
+                a `type` key pointing to the Q-model class.
+                This Q-model is trained in the train() method and is used
+                to compute the state-value estimates for the DirectMethod estimator.
+                It must implement `train` and `estimate_v`.
+                TODO (Rohan138): Unify this with RLModule API.
         """
-        super().__init__(policy, gamma)
+        super().__init__(policy, gamma, epsilon_greedy)
 
         q_model_config = q_model_config or {}
         model_cls = q_model_config.pop("type", FQETorchModel)
 
diff --git a/rllib/offline/estimators/doubly_robust.py b/rllib/offline/estimators/doubly_robust.py
index 6a0e0ffcc926..79dfcfc1af66 100644
--- a/rllib/offline/estimators/doubly_robust.py
+++ b/rllib/offline/estimators/doubly_robust.py
@@ -7,7 +7,6 @@
 from ray.rllib.utils.annotations import DeveloperAPI, override
 from ray.rllib.utils.typing import SampleBatchType
 from ray.rllib.utils.numpy import convert_to_numpy
-from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
 from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
 from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
 
@@ -46,6 +45,7 @@ def __init__(
         self,
         policy: Policy,
         gamma: float,
+        epsilon_greedy: float = 0.0,
         q_model_config: Optional[Dict] = None,
     ):
         """Initializes a Doubly Robust OPE Estimator.
@@ -53,16 +53,19 @@ def __init__(
         Args:
             policy: Policy to evaluate.
             gamma: Discount factor of the environment.
+            epsilon_greedy: The probability with which we act according to a fully
+                random policy during deployment. With probability 1-epsilon_greedy
+                we act according to the target policy.
             q_model_config: Arguments to specify the Q-model. Must specify
-            a `type` key pointing to the Q-model class.
-            This Q-model is trained in the train() method and is used
-            to compute the state-value and Q-value estimates
-            for the DoublyRobust estimator.
-            It must implement `train`, `estimate_q`, and `estimate_v`.
-            TODO (Rohan138): Unify this with RLModule API.
+                a `type` key pointing to the Q-model class.
+                This Q-model is trained in the train() method and is used
+                to compute the state-value and Q-value estimates
+                for the DoublyRobust estimator.
+                It must implement `train`, `estimate_q`, and `estimate_v`.
+                TODO (Rohan138): Unify this with RLModule API.
""" - super().__init__(policy, gamma) + super().__init__(policy, gamma, epsilon_greedy) q_model_config = q_model_config or {} model_cls = q_model_config.pop("type", FQETorchModel) @@ -83,8 +86,7 @@ def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]: estimates_per_epsiode = {} rewards, old_prob = episode["rewards"], episode["action_prob"] - log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, episode) - new_prob = np.exp(convert_to_numpy(log_likelihoods)) + new_prob = self.compute_action_probs(episode) v_behavior = 0.0 v_target = 0.0 @@ -113,8 +115,7 @@ def estimate_on_single_step_samples( estimates_per_epsiode = {} rewards, old_prob = batch["rewards"], batch["action_prob"] - log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch) - new_prob = np.exp(convert_to_numpy(log_likelihoods)) + new_prob = self.compute_action_probs(batch) q_values = self.model.estimate_q(batch) q_values = convert_to_numpy(q_values) diff --git a/rllib/offline/estimators/importance_sampling.py b/rllib/offline/estimators/importance_sampling.py index 162d35a41dea..1fed5deb1658 100644 --- a/rllib/offline/estimators/importance_sampling.py +++ b/rllib/offline/estimators/importance_sampling.py @@ -1,10 +1,7 @@ from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict from typing import Dict, List -import numpy as np @DeveloperAPI @@ -28,8 +25,7 @@ def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, float]: estimates_per_epsiode = {} rewards, old_prob = episode["rewards"], episode["action_prob"] - log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, episode) - new_prob = np.exp(convert_to_numpy(log_likelihoods)) + new_prob = self.compute_action_probs(episode) # calculate importance ratios p = [] @@ -59,8 +55,7 @@ def estimate_on_single_step_samples( estimates_per_epsiode = {} rewards, old_prob = batch["rewards"], batch["action_prob"] - log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch) - new_prob = np.exp(convert_to_numpy(log_likelihoods)) + new_prob = self.compute_action_probs(batch) weights = new_prob / old_prob v_behavior = rewards diff --git a/rllib/offline/estimators/off_policy_estimator.py b/rllib/offline/estimators/off_policy_estimator.py index 05dcdacc1224..8b35820aca55 100644 --- a/rllib/offline/estimators/off_policy_estimator.py +++ b/rllib/offline/estimators/off_policy_estimator.py @@ -1,3 +1,4 @@ +import gym import numpy as np import tree from typing import Dict, Any, List @@ -12,6 +13,7 @@ from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict from ray.rllib.utils.annotations import ( DeveloperAPI, + ExperimentalAPI, OverrideToImplementCustomLogic, ) from ray.rllib.utils.deprecation import Deprecated @@ -27,15 +29,25 @@ class OffPolicyEstimator(OfflineEvaluator): """Interface for an off policy estimator for counterfactual evaluation.""" @DeveloperAPI - def __init__(self, policy: Policy, gamma: float = 0.0): + def __init__( + self, + policy: Policy, + gamma: float = 0.0, + epsilon_greedy: float = 0.0, + ): """Initializes an OffPolicyEstimator instance. Args: policy: Policy to evaluate. gamma: Discount factor of the environment. 
+            epsilon_greedy: The probability with which we act according to a fully
+                random policy during deployment. With probability 1-epsilon_greedy
+                we act according to the target policy.
+            # TODO (kourosh): convert the input parameters to a config dict.
         """
         self.policy = policy
         self.gamma = gamma
+        self.epsilon_greedy = epsilon_greedy
 
     @DeveloperAPI
     def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]:
@@ -228,6 +240,22 @@ def check_action_prob_in_batch(self, batch: SampleBatchType) -> None:
                 "`off_policy_estimation_methods: {}` to disable estimation."
             )
 
+    @ExperimentalAPI
+    def compute_action_probs(self, batch: SampleBatch):
+        log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
+        new_prob = np.exp(convert_to_numpy(log_likelihoods))
+
+        if self.epsilon_greedy > 0.0:
+            if not isinstance(self.policy.action_space, gym.spaces.Discrete):
+                raise ValueError(
+                    "Evaluation with epsilon-greedy exploration is only supported "
+                    "with discrete action spaces."
+                )
+            eps = self.epsilon_greedy
+            new_prob = new_prob * (1 - eps) + eps / self.policy.action_space.n
+
+        return new_prob
+
     @DeveloperAPI
     def train(self, batch: SampleBatchType) -> Dict[str, Any]:
         """Train a model for Off-Policy Estimation.
diff --git a/rllib/offline/estimators/tests/test_ope.py b/rllib/offline/estimators/tests/test_ope.py
index 53c43b27e34c..e8e3e4b98455 100644
--- a/rllib/offline/estimators/tests/test_ope.py
+++ b/rllib/offline/estimators/tests/test_ope.py
@@ -64,10 +64,10 @@ def setUpClass(cls):
                 evaluation_num_workers=1,
                 evaluation_duration_unit="episodes",
                 off_policy_estimation_methods={
-                    "is": {"type": ImportanceSampling},
-                    "wis": {"type": WeightedImportanceSampling},
-                    "dm_fqe": {"type": DirectMethod},
-                    "dr_fqe": {"type": DoublyRobust},
+                    "is": {"type": ImportanceSampling, "epsilon_greedy": 0.1},
+                    "wis": {"type": WeightedImportanceSampling, "epsilon_greedy": 0.1},
+                    "dm_fqe": {"type": DirectMethod, "epsilon_greedy": 0.1},
+                    "dr_fqe": {"type": DoublyRobust, "epsilon_greedy": 0.1},
                 },
             )
         )
diff --git a/rllib/offline/estimators/weighted_importance_sampling.py b/rllib/offline/estimators/weighted_importance_sampling.py
index 5854e2b48aa4..c03923113f22 100644
--- a/rllib/offline/estimators/weighted_importance_sampling.py
+++ b/rllib/offline/estimators/weighted_importance_sampling.py
@@ -3,10 +3,8 @@
 
 from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
 from ray.rllib.policy import Policy
 from ray.rllib.utils.annotations import override, DeveloperAPI
-from ray.rllib.utils.numpy import convert_to_numpy
 
 
 @DeveloperAPI
@@ -29,8 +27,8 @@ class WeightedImportanceSampling(OffPolicyEstimator):
     For more information refer to https://arxiv.org/pdf/1911.06854.pdf"""
 
     @override(OffPolicyEstimator)
-    def __init__(self, policy: Policy, gamma: float):
-        super().__init__(policy, gamma)
+    def __init__(self, policy: Policy, gamma: float, epsilon_greedy: float = 0.0):
+        super().__init__(policy, gamma, epsilon_greedy)
         # map from time to cummulative propensity values
         self.cummulative_ips_values = []
         # map from time to number of episodes that reached this time
@@ -70,8 +68,7 @@ def estimate_on_single_step_samples(
     ) -> Dict[str, List[float]]:
         estimates_per_epsiode = {}
         rewards, old_prob = batch["rewards"], batch["action_prob"]
-        log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
-        new_prob = np.exp(convert_to_numpy(log_likelihoods))
+        new_prob = self.compute_action_probs(batch)
 
         weights = new_prob / old_prob
         v_behavior = rewards
@@ -95,8 +92,7 @@ def on_before_split_batch_by_episode(
     @override(OffPolicyEstimator)
     def peek_on_single_episode(self, episode: SampleBatch) -> None:
         old_prob = episode["action_prob"]
-        log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, episode)
-        new_prob = np.exp(convert_to_numpy(log_likelihoods))
+        new_prob = self.compute_action_probs(episode)
 
         # calculate importance ratios
         episode_p = []
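Note on the new compute_action_probs helper (this note is not part of the patch): with epsilon-greedy deployment over a discrete action space of n actions, the deployed policy is a mixture of the target policy and a uniform random policy, so each action probability becomes p_deploy = (1 - epsilon_greedy) * p_target + epsilon_greedy / n, which is the correction applied to new_prob above. A minimal standalone NumPy sketch of that correction (the function name and toy numbers below are illustrative, not taken from RLlib):

import numpy as np

def epsilon_greedy_action_probs(target_probs, epsilon, n_actions):
    """Mix target-policy action probabilities with a uniform random policy.

    Mirrors the correction in OffPolicyEstimator.compute_action_probs:
    p_deploy(a|s) = (1 - epsilon) * p_target(a|s) + epsilon / n_actions.
    """
    target_probs = np.asarray(target_probs, dtype=np.float64)
    return target_probs * (1.0 - epsilon) + epsilon / n_actions

# Probabilities the target policy assigns to the logged actions.
print(epsilon_greedy_action_probs([0.9, 0.5, 0.05], epsilon=0.1, n_actions=4))
# -> [0.835 0.475 0.07]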