From 1ecab7d2011d4bff1964e43b598c8d572d393a13 Mon Sep 17 00:00:00 2001 From: Jun Gong Date: Tue, 15 Feb 2022 11:14:06 -0800 Subject: [PATCH 1/8] Update bandit_envs_recommender_system with Sven's in-house implementation, which supports multiple users and slate sizes. Make sure example tune_lin_ucb_train_recommendation.py nails this environment. --- rllib/env/wrappers/recsim.py | 7 +- .../tune_lin_ucb_train_recommendation.py | 16 +- .../bandit/tune_lin_ucb_train_recsim_env.py | 6 +- .../env/bandit_envs_recommender_system.py | 359 ++++++++++-------- 4 files changed, 229 insertions(+), 159 deletions(-) diff --git a/rllib/env/wrappers/recsim.py b/rllib/env/wrappers/recsim.py index 2a29af6a1b66..4d237276233a 100644 --- a/rllib/env/wrappers/recsim.py +++ b/rllib/env/wrappers/recsim.py @@ -84,15 +84,12 @@ def __init__(self, env: gym.Env): self.observation_space = Dict( OrderedDict( [ - ("user", obs_space["user"]), ( "item", gym.spaces.Box( - low=-np.ones((num_items, embedding_dim)), - high=np.ones((num_items, embedding_dim)), + low=-1.0, high=1.0, shape=(num_items, embedding_dim) ), ), - ("response", obs_space["response"]), ] ) ) @@ -100,9 +97,7 @@ def __init__(self, env: gym.Env): def observation(self, obs): new_obs = OrderedDict() - new_obs["user"] = obs["user"] new_obs["item"] = np.vstack(list(obs["doc"].values())) - new_obs["response"] = obs["response"] new_obs = convert_element_to_space_type(new_obs, self._sampled_obs) return new_obs diff --git a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py index d9bcbf997d34..5aa616f21459 100644 --- a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py +++ b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py @@ -8,7 +8,7 @@ import ray from ray import tune -from ray.rllib.examples.env.bandit_envs_recommender_system import ParametricItemRecoEnv +from ray.rllib.examples.env.bandit_envs_recommender_system import ParametricRecSys if __name__ == "__main__": # Temp fix to avoid OMP conflict. @@ -17,8 +17,20 @@ ray.init() config = { - "env": ParametricItemRecoEnv, + "env": "ParametricRecSysEnv", + "env_config": { + "embedding_size": 20, + "num_docs_to_select_from": 10, + "slate_size": 1, + "num_docs_in_db": 100, + "num_users_in_db": 1, + "user_time_budget": 1.0, + }, "num_envs_per_worker": 2, # Test with batched inference. 
+ "evaluation_interval": 20, + "evaluation_duration": 100, + "evaluation_duration_unit": "episodes", + "simple_optimizer": True, } # Actual training_iterations will be 10 * timesteps_per_iteration diff --git a/rllib/examples/bandit/tune_lin_ucb_train_recsim_env.py b/rllib/examples/bandit/tune_lin_ucb_train_recsim_env.py index aae2f33e2e8a..251a2ecc28da 100644 --- a/rllib/examples/bandit/tune_lin_ucb_train_recsim_env.py +++ b/rllib/examples/bandit/tune_lin_ucb_train_recsim_env.py @@ -21,14 +21,16 @@ # Then: "env": [the imported RecSim class] "env": "RecSim-v1", "env_config": { + "num_candidates": 10, + "slate_size": 1, "convert_to_discrete_action_space": True, "wrap_for_bandits": True, }, } # Actual training_iterations will be 10 * timesteps_per_iteration - # (100 by default) = 2,000 - training_iterations = 10 + # (100 by default) = 100,000 + training_iterations = 5000 print("Running training for %s time steps" % training_iterations) diff --git a/rllib/examples/env/bandit_envs_recommender_system.py b/rllib/examples/env/bandit_envs_recommender_system.py index 7a6dcc19860a..fce376fe8903 100644 --- a/rllib/examples/env/bandit_envs_recommender_system.py +++ b/rllib/examples/env/bandit_envs_recommender_system.py @@ -1,166 +1,227 @@ -import copy - +"""Examples for recommender system simulating envs ready to be used by + RLlib Trainers. + This env follows RecSim obs and action APIs. +""" import gym import numpy as np -from gym import spaces - -DEFAULT_RECO_CONFIG = { - "num_users": 1, - "num_items": 100, - "feature_dim": 16, - "slate_size": 1, - "num_candidates": 25, - "seed": 1, -} - - -class ParametricItemRecoEnv(gym.Env): - """A recommendation environment which generates items with visible features - randomly (parametric actions). - The environment can be configured to be multi-user, i.e. different models - will be learned independently for each user. - To enable slate recommendation, the `slate_size` config parameter can be - set as > 1. - """ - - def __init__(self, config=None): - self.config = copy.copy(DEFAULT_RECO_CONFIG) - if config is not None and type(config) == dict: - self.config.update(config) - - self.num_users = self.config["num_users"] - self.num_items = self.config["num_items"] - self.feature_dim = self.config["feature_dim"] - self.slate_size = self.config["slate_size"] - self.num_candidates = self.config["num_candidates"] - self.seed = self.config["seed"] - - assert ( - self.num_candidates <= self.num_items - ), "Size of candidate pool should be less than total no. of items" - assert ( - self.slate_size < self.num_candidates - ), "Slate size should be less than no. 
of candidate items" - - self.action_space = self._def_action_space() - self.observation_space = self._def_observation_space() - - self.current_user_id = 0 - self.item_pool = None - self.item_pool_ids = None - self.total_regret = 0 - - self._init_embeddings() - - def _init_embeddings(self): - self.item_embeddings = self._gen_normalized_embeddings( - self.num_items, self.feature_dim +from typing import List, Optional + +from ray.rllib.env.wrappers.recsim import ( + MultiDiscreteToDiscreteActionWrapper, + RecSimObservationBanditWrapper, +) +from ray.rllib.utils.numpy import softmax +from ray.tune import register_env + + +class ParametricRecSys(gym.Env): + def __init__( + self, + embedding_size: int = 20, + num_docs_to_select_from: int = 10, + slate_size: int = 1, + num_docs_in_db: Optional[int] = None, + num_users_in_db: Optional[int] = None, + user_time_budget: float = 60.0, + ): + """Initializes a ParametricRecSys instance. + + Args: + embedding_size: Embedding size for both users and docs. + Each value in the user/doc embeddings can have values between + -1.0 and 1.0. + num_docs_to_select_from: The number of documents to present to the + agent each timestep. The agent will then have to pick a slate + out of these. + slate_size: The size of the slate to recommend to the user at each + timestep. + num_docs_in_db: The total number of documents in the DB. Set this + to None, in case you would like to resample docs from an + infinite pool. + num_users_in_db: The total number of users in the DB. Set this to + None, in case you would like to resample users from an infinite + pool. + user_time_budget: The total time budget a user has throughout an + episode. Once this time budget is used up (through engagements + with clicked/selected documents), the episode ends. + """ + self.embedding_size = embedding_size + self.num_docs_to_select_from = num_docs_to_select_from + self.slate_size = slate_size + + self.num_docs_in_db = num_docs_in_db + self.docs_db = None + self.num_users_in_db = num_users_in_db + self.users_db = None + self.current_user = None + + self.user_time_budget = user_time_budget + self.current_user_budget = user_time_budget + + self.observation_space = gym.spaces.Dict( + { + # The D docs our agent sees at each timestep. + # It has to select a k-slate out of these. + "doc": gym.spaces.Dict( + { + str(i): gym.spaces.Box( + -1.0, 1.0, shape=(self.embedding_size,), dtype=np.float32 + ) + for i in range(self.num_docs_to_select_from) + } + ), + # The user engaging in this timestep/episode. + "user": gym.spaces.Box( + -1.0, 1.0, shape=(self.embedding_size,), dtype=np.float32 + ), + # For each item in the previous slate, was it clicked? + # If yes, how long was it being engaged with (e.g. watched)? + "response": gym.spaces.Tuple( + [ + gym.spaces.Dict( + { + # Clicked or not? + "click": gym.spaces.Discrete(2), + # Engagement time (how many minutes watched?). + "engagement": gym.spaces.Box( + 0.0, 100.0, shape=(), dtype=np.float32 + ), + } + ) + for _ in range(self.slate_size) + ] + ), + } ) - - # These are latent user features that will be hidden from the learning - # agent. 
They will be used for reward generation only - self.user_embeddings = self._gen_normalized_embeddings( - self.num_users, self.feature_dim + # Our action space is + self.action_space = gym.spaces.MultiDiscrete( + [self.num_docs_to_select_from for _ in range(self.slate_size)] ) - def _sample_user(self): - self.current_user_id = np.random.randint(0, self.num_users) - - def _gen_item_pool(self): - # Randomly generate a candidate list of items by sampling without - # replacement - self.item_pool_ids = np.random.choice( - np.arange(self.num_items), self.num_candidates, replace=False - ) - self.item_pool = self.item_embeddings[self.item_pool_ids].astype(np.float32) + def _get_embedding(self): + return np.random.uniform(-1, 1, size=(self.embedding_size,)).astype(np.float32) - @staticmethod - def _gen_normalized_embeddings(size, dim): - embeddings = np.random.rand(size, dim) - embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True).astype( - np.float32 - ) - return embeddings - - def _def_action_space(self): - if self.slate_size == 1: - return spaces.Discrete(self.num_candidates) + def reset(self): + # Reset the current user's time budget. + self.current_user_budget = self.user_time_budget + + # Sample a user for the next episode/session. + # Pick from a only-once-sampled user DB. + if self.num_users_in_db is not None: + if self.users_db is None: + self.users_db = [ + self._get_embedding() for _ in range(self.num_users_in_db) + ] + self.current_user = self.users_db[np.random.choice(self.num_users_in_db)] + # Pick from an infinite pool of users. else: - return spaces.MultiDiscrete([self.num_candidates] * self.slate_size) + self.current_user = self._get_embedding() + + return self._get_obs() - def _def_observation_space(self): - # Embeddings for each item in the candidate pool - item_obs_space = spaces.Box( - low=-np.inf, high=np.inf, shape=(self.num_candidates, self.feature_dim) + def step(self, action): + # Action is the suggested slate (indices of the docs in the + # suggested ones). + + scores = [ + np.dot(self.current_user, doc) for doc in self.currently_suggested_docs + ] + best_reward = np.max(scores) + + # User choice model: User picks a doc stochastically, + # where probs are dot products between user- and doc feature + # (categories) vectors (rewards). + # There is also a no-click doc whose weight is 0.0. + user_doc_overlaps = np.array([scores[a] for a in action] + [0.0]) + which_clicked = np.random.choice( + np.arange(self.slate_size + 1), p=softmax(user_doc_overlaps) ) - # Can be useful for collaborative filtering based agents - item_ids_obs_space = spaces.MultiDiscrete( - [self.num_items] * self.num_candidates + r = 0.0 + if which_clicked < self.slate_size: + # Reward is 1.0 - regret if clicked. 0.0 if not clicked. + regret = best_reward - user_doc_overlaps[which_clicked] + r = 1 - regret + # If anything clicked, deduct from the current user's time budget. + self.current_user_budget -= 1.0 + d = self.current_user_budget <= 0.0 + + # Compile response. 
+ response = tuple( + { + "click": int(idx == which_clicked), + "engagement": r if idx == which_clicked else 0.0, + } + for idx in range(len(user_doc_overlaps) - 1) ) - # Can be either binary (clicks) or continuous feedback (watch time) - resp_space = spaces.Box(low=-1, high=1, shape=(self.slate_size,)) - - if self.num_users == 1: - return spaces.Dict( - { - "item": item_obs_space, - "item_id": item_ids_obs_space, - "response": resp_space, - } - ) + return self._get_obs(response=response), r, d, {} + + def _get_obs(self, response=None): + # Sample D docs from infinity or our pre-existing docs. + # Pick from a only-once-sampled docs DB. + if self.num_docs_in_db is not None: + if self.docs_db is None: + self.docs_db = [ + self._get_embedding() for _ in range(self.num_docs_in_db) + ] + self.currently_suggested_docs = [ + self.docs_db[doc_idx].astype(np.float32) + for doc_idx in np.random.choice( + self.num_docs_in_db, + size=(self.num_docs_to_select_from,), + replace=False, + ) + ] + # Pick from an infinite pool of docs. else: - user_obs_space = spaces.Discrete(self.num_users) - return spaces.Dict( - { - "user": user_obs_space, - "item": item_obs_space, - "item_id": item_ids_obs_space, - "response": resp_space, - } - ) + self.currently_suggested_docs = [ + self._get_embedding() for _ in range(self.num_docs_to_select_from) + ] - def step(self, action): - # Action can be a single action or a slate depending on slate size - assert self.action_space.contains( - action - ), "Action cannot be recognized. Please check the type and bounds." - - if self.slate_size == 1: - scores = self.item_pool.dot(self.user_embeddings[self.current_user_id]) - reward = scores[action] - regret = np.max(scores) - reward - self.total_regret += regret - - info = {"regret": regret} - - self.current_user_id = np.random.randint(0, self.num_users) - self._gen_item_pool() - - obs = { - "item": self.item_pool.astype(np.float32), - "item_id": self.item_pool_ids, - "response": [reward], - } - if self.num_users > 1: - obs["user"] = self.current_user_id - return obs, reward, True, info - else: - # TODO(saurabh3949):Handle slate recommendation using a click model - return None + doc = {str(i): d for i, d in enumerate(self.currently_suggested_docs)} - def reset(self): - self._sample_user() - self._gen_item_pool() - obs = { - "item": self.item_pool, - "item_id": self.item_pool_ids, - "response": [0] * self.slate_size, + if not response: + response = self.observation_space["response"].sample() + + return { + "user": self.current_user.astype(np.float32), + "doc": doc, + "response": response, } - if self.num_users > 1: - obs["user"] = self.current_user_id - return obs - def render(self, mode="human"): - raise NotImplementedError + +# Because ParametricRecSys follows RecSim's API, we have to wrap it before +# it can work with our Bandits agent. 
+register_env( + "ParametricRecSysEnv", + lambda config: MultiDiscreteToDiscreteActionWrapper( + RecSimObservationBanditWrapper(ParametricRecSys(**config)) + ), +) + + +if __name__ == "__main__": + """Test RecommSys env with random actions for baseline performance.""" + env = ParametricRecSys( + num_docs_in_db=100, + num_users_in_db=1, + ) + obs = env.reset() + num_episodes = 0 + episode_rewards = [] + episode_reward = 0.0 + + while num_episodes < 100: + action = env.action_space.sample() + obs, reward, done, _ = env.step(action) + + episode_reward += reward + if done: + print(f"episode reward = {episode_reward}") + env.reset() + num_episodes += 1 + episode_rewards.append(episode_reward) + episode_reward = 0.0 + + print(f"Avg reward={np.mean(episode_rewards)}") From 94c1f5775be259c4bc4dc9b8a80751e22f2e6664 Mon Sep 17 00:00:00 2001 From: Jun Gong Date: Tue, 15 Feb 2022 21:48:54 -0800 Subject: [PATCH 2/8] minor --- rllib/examples/bandit/tune_lin_ucb_train_recsim_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/examples/bandit/tune_lin_ucb_train_recsim_env.py b/rllib/examples/bandit/tune_lin_ucb_train_recsim_env.py index 251a2ecc28da..5dd56fe0eb63 100644 --- a/rllib/examples/bandit/tune_lin_ucb_train_recsim_env.py +++ b/rllib/examples/bandit/tune_lin_ucb_train_recsim_env.py @@ -29,7 +29,7 @@ } # Actual training_iterations will be 10 * timesteps_per_iteration - # (100 by default) = 100,000 + # (100 by default) = 500,000 training_iterations = 5000 print("Running training for %s time steps" % training_iterations) From 1f261fc0bee25514460ce51ac374a6c1b62e961a Mon Sep 17 00:00:00 2001 From: Jun Gong Date: Wed, 16 Feb 2022 16:29:33 -0800 Subject: [PATCH 3/8] lint --- rllib/examples/bandit/tune_lin_ucb_train_recommendation.py | 2 +- rllib/examples/env/bandit_envs_recommender_system.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py index 5aa616f21459..f4fbbc70bfd1 100644 --- a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py +++ b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py @@ -8,7 +8,7 @@ import ray from ray import tune -from ray.rllib.examples.env.bandit_envs_recommender_system import ParametricRecSys +from ray.rllib.examples.env import bandit_envs_recommender_system # noqa: F401 if __name__ == "__main__": # Temp fix to avoid OMP conflict. 
diff --git a/rllib/examples/env/bandit_envs_recommender_system.py b/rllib/examples/env/bandit_envs_recommender_system.py index fce376fe8903..58eac2e8c97e 100644 --- a/rllib/examples/env/bandit_envs_recommender_system.py +++ b/rllib/examples/env/bandit_envs_recommender_system.py @@ -4,7 +4,7 @@ """ import gym import numpy as np -from typing import List, Optional +from typing import Optional from ray.rllib.env.wrappers.recsim import ( MultiDiscreteToDiscreteActionWrapper, From 45956d9a27f1d2909f260a2e081706ec33b61d32 Mon Sep 17 00:00:00 2001 From: Jun Gong Date: Thu, 17 Feb 2022 15:06:37 -0800 Subject: [PATCH 4/8] wip --- .../tune_lin_ucb_train_recommendation.py | 18 ++++++++++++++- .../env/bandit_envs_recommender_system.py | 23 +++++++------------ 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py index f4fbbc70bfd1..cb7dc2c20c99 100644 --- a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py +++ b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py @@ -8,7 +8,23 @@ import ray from ray import tune -from ray.rllib.examples.env import bandit_envs_recommender_system # noqa: F401 +from ray.tune import register_env +from ray.rllib.env.wrappers.recsim import ( + MultiDiscreteToDiscreteActionWrapper, + RecSimObservationBanditWrapper, +) +from ray.rllib.examples.env.bandit_envs_recommender_system import ( + ParametricRecSys, +) + +# Because ParametricRecSys follows RecSim's API, we have to wrap it before +# it can work with our Bandits agent. +register_env( + "ParametricRecSysEnv", + lambda cfg: MultiDiscreteToDiscreteActionWrapper( + RecSimObservationBanditWrapper(ParametricRecSys(**cfg)) + ), +) if __name__ == "__main__": # Temp fix to avoid OMP conflict. diff --git a/rllib/examples/env/bandit_envs_recommender_system.py b/rllib/examples/env/bandit_envs_recommender_system.py index 58eac2e8c97e..51036046f820 100644 --- a/rllib/examples/env/bandit_envs_recommender_system.py +++ b/rllib/examples/env/bandit_envs_recommender_system.py @@ -6,15 +6,18 @@ import numpy as np from typing import Optional -from ray.rllib.env.wrappers.recsim import ( - MultiDiscreteToDiscreteActionWrapper, - RecSimObservationBanditWrapper, -) from ray.rllib.utils.numpy import softmax -from ray.tune import register_env class ParametricRecSys(gym.Env): + """A recommendation environment which generates items with visible features + randomly (parametric actions). + The environment can be configured to be multi-user, i.e. different models + will be learned independently for each user, by setting num_users_in_db + parameter. + To enable slate recommendation, the `slate_size` config parameter can be + set as > 1. + """ def __init__( self, embedding_size: int = 20, @@ -191,16 +194,6 @@ def _get_obs(self, response=None): } -# Because ParametricRecSys follows RecSim's API, we have to wrap it before -# it can work with our Bandits agent. 
-register_env( - "ParametricRecSysEnv", - lambda config: MultiDiscreteToDiscreteActionWrapper( - RecSimObservationBanditWrapper(ParametricRecSys(**config)) - ), -) - - if __name__ == "__main__": """Test RecommSys env with random actions for baseline performance.""" env = ParametricRecSys( From d0a1bdf692d461245f0dffe8aab96c20ee19e701 Mon Sep 17 00:00:00 2001 From: Jun Gong Date: Fri, 18 Feb 2022 09:45:08 -0800 Subject: [PATCH 5/8] lint --- rllib/examples/bandit/tune_lin_ucb_train_recommendation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py index cb7dc2c20c99..9e7c1cdb090d 100644 --- a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py +++ b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py @@ -6,7 +6,6 @@ import pandas as pd import time -import ray from ray import tune from ray.tune import register_env from ray.rllib.env.wrappers.recsim import ( From 6104e3182eb3a76f8b9f41b1f092044e66575029 Mon Sep 17 00:00:00 2001 From: Jun Gong Date: Tue, 22 Feb 2022 15:23:02 -0800 Subject: [PATCH 6/8] lint --- rllib/examples/bandit/tune_lin_ucb_train_recommendation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py index 9e7c1cdb090d..cb7dc2c20c99 100644 --- a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py +++ b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py @@ -6,6 +6,7 @@ import pandas as pd import time +import ray from ray import tune from ray.tune import register_env from ray.rllib.env.wrappers.recsim import ( From 66af3c2ec19e3f01e5d7e9f5ef02d8aba00dea34 Mon Sep 17 00:00:00 2001 From: Jun Gong Date: Wed, 23 Feb 2022 00:08:56 -0800 Subject: [PATCH 7/8] lint --- rllib/examples/env/bandit_envs_recommender_system.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rllib/examples/env/bandit_envs_recommender_system.py b/rllib/examples/env/bandit_envs_recommender_system.py index 51036046f820..7a8f9f75fb30 100644 --- a/rllib/examples/env/bandit_envs_recommender_system.py +++ b/rllib/examples/env/bandit_envs_recommender_system.py @@ -18,6 +18,7 @@ class ParametricRecSys(gym.Env): To enable slate recommendation, the `slate_size` config parameter can be set as > 1. """ + def __init__( self, embedding_size: int = 20, From 870962c6b4ebc605ccca5342a61dabb33aac5a98 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 24 Feb 2022 21:25:25 +0100 Subject: [PATCH 8/8] Apply suggestions from code review --- rllib/examples/env/bandit_envs_recommender_system.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rllib/examples/env/bandit_envs_recommender_system.py b/rllib/examples/env/bandit_envs_recommender_system.py index 7a8f9f75fb30..a9ad740bebab 100644 --- a/rllib/examples/env/bandit_envs_recommender_system.py +++ b/rllib/examples/env/bandit_envs_recommender_system.py @@ -141,25 +141,25 @@ def step(self, action): np.arange(self.slate_size + 1), p=softmax(user_doc_overlaps) ) - r = 0.0 + reward = 0.0 if which_clicked < self.slate_size: # Reward is 1.0 - regret if clicked. 0.0 if not clicked. regret = best_reward - user_doc_overlaps[which_clicked] - r = 1 - regret + reward = 1 - regret # If anything clicked, deduct from the current user's time budget. self.current_user_budget -= 1.0 - d = self.current_user_budget <= 0.0 + done = self.current_user_budget <= 0.0 # Compile response. 
response = tuple( { "click": int(idx == which_clicked), - "engagement": r if idx == which_clicked else 0.0, + "engagement": reward if idx == which_clicked else 0.0, } for idx in range(len(user_doc_overlaps) - 1) ) - return self._get_obs(response=response), r, d, {} + return self._get_obs(response=response), reward, done, {} def _get_obs(self, response=None): # Sample D docs from infinity or our pre-existing docs.