diff --git a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py index f4fbbc70bfd1..cb7dc2c20c99 100644 --- a/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py +++ b/rllib/examples/bandit/tune_lin_ucb_train_recommendation.py @@ -8,7 +8,23 @@ import ray from ray import tune -from ray.rllib.examples.env import bandit_envs_recommender_system # noqa: F401 +from ray.tune import register_env +from ray.rllib.env.wrappers.recsim import ( + MultiDiscreteToDiscreteActionWrapper, + RecSimObservationBanditWrapper, +) +from ray.rllib.examples.env.bandit_envs_recommender_system import ( + ParametricRecSys, +) + +# Because ParametricRecSys follows RecSim's API, we have to wrap it before +# it can work with our Bandits agent. +register_env( + "ParametricRecSysEnv", + lambda cfg: MultiDiscreteToDiscreteActionWrapper( + RecSimObservationBanditWrapper(ParametricRecSys(**cfg)) + ), +) if __name__ == "__main__": # Temp fix to avoid OMP conflict. diff --git a/rllib/examples/env/bandit_envs_recommender_system.py b/rllib/examples/env/bandit_envs_recommender_system.py index 58eac2e8c97e..51036046f820 100644 --- a/rllib/examples/env/bandit_envs_recommender_system.py +++ b/rllib/examples/env/bandit_envs_recommender_system.py @@ -6,15 +6,18 @@ import numpy as np from typing import Optional -from ray.rllib.env.wrappers.recsim import ( - MultiDiscreteToDiscreteActionWrapper, - RecSimObservationBanditWrapper, -) from ray.rllib.utils.numpy import softmax -from ray.tune import register_env class ParametricRecSys(gym.Env): + """A recommendation environment which generates items with visible features + randomly (parametric actions). + The environment can be configured to be multi-user, i.e. different models + will be learned independently for each user, by setting the `num_users_in_db` + parameter. + To enable slate recommendation, the `slate_size` config parameter can be + set to a value > 1. 
+ """ def __init__( self, embedding_size: int = 20, @@ -191,16 +194,6 @@ def _get_obs(self, response=None): } -# Because ParametricRecSys follows RecSim's API, we have to wrap it before -# it can work with our Bandits agent. -register_env( - "ParametricRecSysEnv", - lambda config: MultiDiscreteToDiscreteActionWrapper( - RecSimObservationBanditWrapper(ParametricRecSys(**config)) - ), -) - - if __name__ == "__main__": """Test RecommSys env with random actions for baseline performance.""" env = ParametricRecSys(