Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Jun Gong committed Feb 22, 2022
1 parent 1f261fc commit 45956d9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
18 changes: 17 additions & 1 deletion rllib/examples/bandit/tune_lin_ucb_train_recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,23 @@

import ray
from ray import tune
from ray.rllib.examples.env import bandit_envs_recommender_system # noqa: F401
from ray.tune import register_env
from ray.rllib.env.wrappers.recsim import (
MultiDiscreteToDiscreteActionWrapper,
RecSimObservationBanditWrapper,
)
from ray.rllib.examples.env.bandit_envs_recommender_system import (
ParametricRecSys,
)

# Because ParametricRecSys follows RecSim's API, we have to wrap it before
# it can work with our Bandits agent.
register_env(
"ParametricRecSysEnv",
lambda cfg: MultiDiscreteToDiscreteActionWrapper(
RecSimObservationBanditWrapper(ParametricRecSys(**cfg))
),
)

if __name__ == "__main__":
# Temp fix to avoid OMP conflict.
Expand Down
23 changes: 8 additions & 15 deletions rllib/examples/env/bandit_envs_recommender_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
import numpy as np
from typing import Optional

from ray.rllib.env.wrappers.recsim import (
MultiDiscreteToDiscreteActionWrapper,
RecSimObservationBanditWrapper,
)
from ray.rllib.utils.numpy import softmax
from ray.tune import register_env


class ParametricRecSys(gym.Env):
"""A recommendation environment which generates items with visible features
randomly (parametric actions).
The environment can be configured to be multi-user, i.e. different models
will be learned independently for each user, by setting num_users_in_db
parameter.
To enable slate recommendation, the `slate_size` config parameter can be
set as > 1.
"""
def __init__(
self,
embedding_size: int = 20,
Expand Down Expand Up @@ -191,16 +194,6 @@ def _get_obs(self, response=None):
}


# Because ParametricRecSys follows RecSim's API, we have to wrap it before
# it can work with our Bandits agent.
register_env(
"ParametricRecSysEnv",
lambda config: MultiDiscreteToDiscreteActionWrapper(
RecSimObservationBanditWrapper(ParametricRecSys(**config))
),
)


if __name__ == "__main__":
"""Test RecommSys env with random actions for baseline performance."""
env = ParametricRecSys(
Expand Down

0 comments on commit 45956d9

Please sign in to comment.