diff --git a/rllib/BUILD b/rllib/BUILD
index d4b4049300ac..152ddb6ae493 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -880,7 +880,7 @@ py_test(
 py_test(
     name = "test_bandits",
     tags = ["team:rllib", "algorithms_dir"],
-    size = "medium",
+    size = "large",
     srcs = ["algorithms/bandit/tests/test_bandits.py"],
 )
 
diff --git a/rllib/algorithms/bandit/bandit_tf_model.py b/rllib/algorithms/bandit/bandit_tf_model.py
index dbc5c6dd690c..e4a4f951d8b1 100644
--- a/rllib/algorithms/bandit/bandit_tf_model.py
+++ b/rllib/algorithms/bandit/bandit_tf_model.py
@@ -32,12 +32,17 @@ def __init__(self, feature_dim, alpha=1, lambda_=1):
         )
 
         self._init_params()
+        self.dist = self._make_dist()
 
     def _init_params(self):
         self.covariance.assign(self.covariance * self.alpha)
-        self.dist = tfp.distributions.MultivariateNormalTriL(
+
+    def _make_dist(self):
+        """Create a multivariate normal distribution with the current parameters."""
+        dist = tfp.distributions.MultivariateNormalTriL(
             self.theta, scale_tril=tf.linalg.cholesky(self.covariance)
         )
+        return dist
 
     def partial_fit(self, x, y):
         x, y = self._check_inputs(x, y)
@@ -55,6 +60,11 @@ def partial_fit(self, x, y):
         self.covariance.assign(tf.linalg.inv(self.precision))
         self.theta.assign(tf.linalg.matvec(self.covariance, self.f))
         self.covariance.assign(self.covariance * self.alpha)
+        # The multivariate dist needs to be reconstructed every time its
+        # parameters are updated: the parameters of the dist do not update
+        # when the stored self.covariance and self.theta (the mean) are
+        # updated.
+        self.dist = self._make_dist()
 
     def sample_theta(self):
         theta = self.dist.sample()
diff --git a/rllib/algorithms/bandit/bandit_torch_model.py b/rllib/algorithms/bandit/bandit_torch_model.py
index 48fd0855c422..5740fc36af2f 100644
--- a/rllib/algorithms/bandit/bandit_torch_model.py
+++ b/rllib/algorithms/bandit/bandit_torch_model.py
@@ -37,6 +37,7 @@ def __init__(self, feature_dim, alpha=1, lambda_=1):
             data=self.covariance.matmul(self.f), requires_grad=False
         )
         self._init_params()
+        self.dist = self._make_dist()
 
     def _init_params(self):
         self.update_schedule = 1
@@ -44,9 +45,13 @@ def _init_params(self):
         self.delta_b = 0
         self.time = 0
         self.covariance.mul_(self.alpha)
-        self.dist = torch.distributions.multivariate_normal.MultivariateNormal(
-            self.theta, self.covariance
+
+    def _make_dist(self):
+        """Create a multivariate normal distribution from the current parameters."""
+        dist = torch.distributions.multivariate_normal.MultivariateNormal(
+            loc=self.theta, precision_matrix=self.precision
         )
+        return dist
 
     def partial_fit(self, x, y):
         x, y = self._check_inputs(x, y)
@@ -54,16 +59,21 @@ def partial_fit(self, x, y):
         x, y = self._check_inputs(x, y)
         y = y.item()
         self.time += 1
         self.delta_f += y * x
-        self.delta_b += torch.ger(x, x)
+        self.delta_b += torch.outer(x, x)
         # Can follow an update schedule if not doing Sherman-Morrison updates
         if self.time % self.update_schedule == 0:
             self.precision += self.delta_b
             self.f += self.delta_f
             self.delta_b = 0
             self.delta_f = 0
-            torch.inverse(self.precision, out=self.covariance)
-            torch.matmul(self.covariance, self.f, out=self.theta)
-            self.covariance.mul_(self.alpha)
+            self.covariance.data = torch.inverse(self.precision)
+            self.theta.data = torch.matmul(self.covariance, self.f)
+            self.covariance.data *= self.alpha
+            # The multivariate dist needs to be reconstructed every time its
+            # parameters are updated: the parameters of the dist do not update
+            # when the stored self.covariance and self.theta (the mean) are
+            # updated.
+            self.dist = self._make_dist()
 
     def sample_theta(self):
         theta = self.dist.sample()
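
Aside on the staleness bug the two model changes above fix (illustrative, not part of the patch): both torch.distributions.MultivariateNormal and tfp.distributions.MultivariateNormalTriL factorize their scale matrix once, at construction time, so later in-place updates to the tensors they were built from never reach the sampler. A minimal repro of this behavior in torch:

import torch

# Build a unit-covariance normal, then update the covariance tensor in place.
cov = torch.eye(2)
dist = torch.distributions.MultivariateNormal(torch.zeros(2), cov)
cov.mul_(10000.0)

# Samples still have std ~1.0: the Cholesky factor was computed when the
# distribution was constructed and is never recomputed, so the update to
# `cov` is invisible to the sampler. This is why the patch rebuilds the
# distribution via _make_dist() after every parameter update in partial_fit.
print(dist.sample((1000,)).std(0))
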
diff --git a/rllib/algorithms/bandit/tests/test_bandits.py b/rllib/algorithms/bandit/tests/test_bandits.py
index 4ae2647b7d11..6b79923c62a5 100644
--- a/rllib/algorithms/bandit/tests/test_bandits.py
+++ b/rllib/algorithms/bandit/tests/test_bandits.py
@@ -1,9 +1,35 @@
+import gym
+from gym.spaces import Discrete, Box
+import numpy as np
 import unittest
 
 import ray
-from ray.rllib.algorithms.bandit import bandit
+from ray.rllib.algorithms.bandit.bandit import BanditLinTSConfig, BanditLinUCBConfig
 from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit
-from ray.rllib.utils.test_utils import check_train_results, framework_iterator
+from ray.rllib.env import EnvContext
+from ray.rllib.utils.test_utils import check_train_results, framework_iterator, check
+from ray.rllib.utils.numpy import convert_to_numpy
+
+
+class NonContextualBanditEnv(gym.Env):
+    def __init__(self, config: EnvContext):
+        best_arm_prob = config.get("best_arm_prob", 0.5)
+        self.action_space = Discrete(2)
+        self.observation_space = Box(0.0, 1.0, shape=(1,), dtype=np.float32)
+        self.seed(0)
+        self._arm_probs = {0: 0.1, 1: best_arm_prob}
+
+    def reset(self):
+        return [1.0]
+
+    def step(self, action):
+        reward = self.rng.binomial(1, self._arm_probs[action])
+        return ([1.0], reward, True, {})
+
+    def seed(self, seed=0):
+        self._seed = seed
+        if seed is not None:
+            self.rng = np.random.default_rng(self._seed)
 
 
 class TestBandits(unittest.TestCase):
@@ -18,7 +44,7 @@ def tearDownClass(cls) -> None:
     def test_bandit_lin_ts_compilation(self):
         """Test whether BanditLinTS can be built on all frameworks."""
         config = (
-            bandit.BanditLinTSConfig()
+            BanditLinTSConfig()
             .environment(env=SimpleContextualBandit)
             .rollouts(num_rollout_workers=2, num_envs_per_worker=2)
         )
@@ -42,7 +68,7 @@ def test_bandit_lin_ts_compilation(self):
     def test_bandit_lin_ucb_compilation(self):
         """Test whether BanditLinUCB can be built on all frameworks."""
         config = (
-            bandit.BanditLinUCBConfig()
+            BanditLinUCBConfig()
             .environment(env=SimpleContextualBandit)
             .rollouts(num_envs_per_worker=2)
         )
@@ -64,6 +90,55 @@ def test_bandit_lin_ucb_compilation(self):
         self.assertTrue(results["episode_reward_mean"] == 10.0)
         algo.stop()
 
+    def test_bandit_convergence(self):
+        # Test whether, in a simple bandit environment, the bandit algorithms'
+        # learned distributions empirically converge to the optimal one.
+
+        std_threshold = 0.1
+        best_arm_prob = 0.5
+
+        for config_cls in [BanditLinUCBConfig, BanditLinTSConfig]:
+            config = (
+                config_cls()
+                .debugging(seed=0)
+                .environment(
+                    env=NonContextualBanditEnv,
+                    env_config={"best_arm_prob": best_arm_prob},
+                )
+            )
+            for _ in framework_iterator(
+                config, frameworks=("tf2", "torch"), with_eager_tracing=True
+            ):
+                algo = config.build()
+                model = algo.get_policy().model
+                arm_means, arm_stds = [], []
+                for _ in range(50):
+                    # TODO: The internals of the model are leaking here.
+                    # We should revisit this once the RLModule is merged in.
+                    samples = [model.arms[i].dist.sample((1000,)) for i in range(2)]
+                    arm_means.append(
+                        [float(convert_to_numpy(s).mean(0)) for s in samples]
+                    )
+                    arm_stds.append(
+                        [float(convert_to_numpy(s).std(0)) for s in samples]
+                    )
+                    algo.train()
+
+                best_arm = np.argmax(arm_means[-1])
+                print(
+                    f"best arm: {best_arm}, arm means: {arm_means[-1]}, "
+                    f"arm stds: {arm_stds[-1]}"
+                )
+
+                # The better arm (according to the learned model) should be
+                # sufficiently exploited, so it should have low variance at convergence.
+                self.assertLess(arm_stds[-1][best_arm], std_threshold)
+
+                # The best arm should also have a good estimate of its actual mean.
+                # Note that this may not be true for non-optimal arms, as they
+                # may not have been explored enough.
+                check(arm_means[-1][best_arm], best_arm_prob, decimals=1)
+
 
 if __name__ == "__main__":
     import pytest
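
Note on why the std assertion in test_bandit_convergence is a reasonable convergence signal (a back-of-the-envelope sketch, assuming the default alpha=1, lambda_=1 from the model constructors above): the linear models accumulate precision = lambda_*I + sum_i x_i x_i^T, and NonContextualBanditEnv always emits the constant feature x = [1.0], so after n pulls an arm's posterior variance is roughly alpha / (lambda_ + n) and its std shrinks like 1/sqrt(n):

import numpy as np

# Posterior std of an arm after n pulls with the constant feature x = [1.0]:
# precision = lambda_ + n, so variance = alpha / (lambda_ + n).
alpha, lambda_ = 1.0, 1.0
for n in [10, 100, 1000]:
    std = np.sqrt(alpha / (lambda_ + n))
    print(f"pulls={n:5d}  posterior std ~ {std:.3f}")

An arm's posterior std therefore drops below the test's std_threshold of 0.1 once it has been pulled about a hundred times, which is why a sufficiently exploited best arm should pass the assertLess check.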