[RLlib] Fix cov matrix bug bandits #29867

Merged
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -880,7 +880,7 @@ py_test(
 py_test(
     name = "test_bandits",
     tags = ["team:rllib", "algorithms_dir"],
-    size = "medium",
+    size = "large",
     srcs = ["algorithms/bandit/tests/test_bandits.py"],
 )

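(Context for the `size` bump above: in Bazel, a test's `size` sets its default timeout — roughly 300 s for "medium" vs. 900 s for "large" — so the longer-running convergence test added below gets a larger time budget.)
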
12 changes: 11 additions & 1 deletion rllib/algorithms/bandit/bandit_tf_model.py
@@ -32,12 +32,17 @@ def __init__(self, feature_dim, alpha=1, lambda_=1):
         )

         self._init_params()
+        self.dist = self._make_dist()

     def _init_params(self):
         self.covariance.assign(self.covariance * self.alpha)
-        self.dist = tfp.distributions.MultivariateNormalTriL(
+
+    def _make_dist(self):
+        """Create a multivariate normal distribution with the current parameters."""
+        dist = tfp.distributions.MultivariateNormalTriL(
             self.theta, scale_tril=tf.linalg.cholesky(self.covariance)
         )
+        return dist

     def partial_fit(self, x, y):
         x, y = self._check_inputs(x, y)
@@ -55,6 +60,11 @@ def partial_fit(self, x, y):
         self.covariance.assign(tf.linalg.inv(self.precision))
         self.theta.assign(tf.linalg.matvec(self.covariance, self.f))
         self.covariance.assign(self.covariance * self.alpha)
+        # The multivariate dist needs to be reconstructed every time its
+        # parameters are updated: the dist's parameters do not update when
+        # the stored self.covariance and self.theta (the mean) are
+        # updated in place.
+        self.dist = self._make_dist()

     def sample_theta(self):
         theta = self.dist.sample()
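The hunks above are the heart of the fix. A minimal sketch of the underlying behavior (not part of the PR, assuming TF2 eager mode): `tf.linalg.cholesky` is evaluated once when the distribution is built, so later updates to the variable never reach the distribution unless it is rebuilt.

import tensorflow as tf
import tensorflow_probability as tfp

cov = tf.Variable(tf.eye(2))
dist = tfp.distributions.MultivariateNormalTriL(
    loc=tf.zeros(2), scale_tril=tf.linalg.cholesky(cov)
)
cov.assign(cov * 100.0)  # update the stored covariance, as partial_fit does
print(dist.stddev().numpy())  # still ~[1. 1.]: the old Cholesky factor is used
dist = tfp.distributions.MultivariateNormalTriL(  # rebuild, as _make_dist does
    loc=tf.zeros(2), scale_tril=tf.linalg.cholesky(cov)
)
print(dist.stddev().numpy())  # now ~[10. 10.]
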
22 changes: 16 additions & 6 deletions rllib/algorithms/bandit/bandit_torch_model.py
@@ -37,33 +37,43 @@ def __init__(self, feature_dim, alpha=1, lambda_=1):
             data=self.covariance.matmul(self.f), requires_grad=False
         )
         self._init_params()
+        self.dist = self._make_dist()

     def _init_params(self):
         self.update_schedule = 1
         self.delta_f = 0
         self.delta_b = 0
         self.time = 0
         self.covariance.mul_(self.alpha)
-        self.dist = torch.distributions.multivariate_normal.MultivariateNormal(
-            self.theta, self.covariance
+
+    def _make_dist(self):
+        """Create a multivariate normal distribution from the current parameters."""
+        dist = torch.distributions.multivariate_normal.MultivariateNormal(
+            loc=self.theta, precision_matrix=self.precision
         )
+        return dist

     def partial_fit(self, x, y):
         x, y = self._check_inputs(x, y)
         x = x.squeeze(0)
         y = y.item()
         self.time += 1
         self.delta_f += y * x
-        self.delta_b += torch.ger(x, x)
+        self.delta_b += torch.outer(x, x)
         # Can follow an update schedule if not doing Sherman-Morrison updates
         if self.time % self.update_schedule == 0:
             self.precision += self.delta_b
             self.f += self.delta_f
             self.delta_b = 0
             self.delta_f = 0
-            torch.inverse(self.precision, out=self.covariance)
-            torch.matmul(self.covariance, self.f, out=self.theta)
-            self.covariance.mul_(self.alpha)
+            self.covariance.data = torch.inverse(self.precision)
+            self.theta.data = torch.matmul(self.covariance, self.f)
+            self.covariance.data *= self.alpha
+            # The multivariate dist needs to be reconstructed every time its
+            # parameters are updated: the dist's parameters do not update when
+            # the stored self.covariance and self.theta (the mean) are
+            # updated in place.
+            self.dist = self._make_dist()

     def sample_theta(self):
         theta = self.dist.sample()
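Same story on the torch side, with one extra wrinkle: the rebuilt distribution is now parameterized by the precision matrix. A quick sketch (not from the PR) showing that the two parameterizations agree, and that in-place updates likewise don't reach an already-built `MultivariateNormal`:

import torch
from torch.distributions import MultivariateNormal

theta = torch.zeros(3)
precision = torch.eye(3) * 4.0

d_prec = MultivariateNormal(loc=theta, precision_matrix=precision)
d_cov = MultivariateNormal(loc=theta, covariance_matrix=precision.inverse())
# Same distribution, two parameterizations: covariance = inverse(precision).
assert torch.allclose(d_prec.covariance_matrix, d_cov.covariance_matrix)

precision.mul_(100.0)  # in-place update, like the ones in partial_fit
# d_prec still reflects the old precision: its Cholesky factor was
# computed at construction time, hence the reconstruction in _make_dist.
assert torch.allclose(d_prec.covariance_matrix, d_cov.covariance_matrix)
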
83 changes: 79 additions & 4 deletions rllib/algorithms/bandit/tests/test_bandits.py
@@ -1,9 +1,35 @@
 import gym
+from gym.spaces import Discrete, Box
 import numpy as np
 import unittest

 import ray
-from ray.rllib.algorithms.bandit import bandit
+from ray.rllib.algorithms.bandit.bandit import BanditLinTSConfig, BanditLinUCBConfig
 from ray.rllib.examples.env.bandit_envs_discrete import SimpleContextualBandit
-from ray.rllib.utils.test_utils import check_train_results, framework_iterator
+from ray.rllib.env import EnvContext
+from ray.rllib.utils.test_utils import check_train_results, framework_iterator, check
+from ray.rllib.utils.numpy import convert_to_numpy
+
+
+class NonContextualBanditEnv(gym.Env):
+    def __init__(self, config: EnvContext):
+        best_arm_prob = config.get("best_arm_prob", 0.5)
+        self.action_space = Discrete(2)
+        self.observation_space = Box(0.0, 1.0, shape=(1,), dtype=np.float32)
+        self.seed(0)
+        self._arm_probs = {0: 0.1, 1: best_arm_prob}
+
+    def reset(self):
+        return [1.0]
+
+    def step(self, action):
+        reward = self.rng.binomial(1, self._arm_probs[action])
+        return ([1.0], reward, True, {})
+
+    def seed(self, seed=0):
+        self._seed = seed
+        if seed is not None:
+            self.rng = np.random.default_rng(self._seed)
+

 class TestBandits(unittest.TestCase):
@@ -18,7 +44,7 @@ def tearDownClass(cls) -> None:
     def test_bandit_lin_ts_compilation(self):
         """Test whether BanditLinTS can be built on all frameworks."""
         config = (
-            bandit.BanditLinTSConfig()
+            BanditLinTSConfig()
             .environment(env=SimpleContextualBandit)
             .rollouts(num_rollout_workers=2, num_envs_per_worker=2)
         )
@@ -42,7 +68,7 @@ def test_bandit_lin_ts_compilation(self):
     def test_bandit_lin_ucb_compilation(self):
         """Test whether BanditLinUCB can be built on all frameworks."""
         config = (
-            bandit.BanditLinUCBConfig()
+            BanditLinUCBConfig()
             .environment(env=SimpleContextualBandit)
             .rollouts(num_envs_per_worker=2)
         )
@@ -64,6 +90,55 @@ def test_bandit_lin_ucb_compilation(self):
         self.assertTrue(results["episode_reward_mean"] == 10.0)
         algo.stop()

+    def test_bandit_convergence(self):
+        # Test whether, in a simple bandit environment, the learned arm
+        # distributions empirically converge to the optimal distribution.
+
+        std_threshold = 0.1
+        best_arm_prob = 0.5
+
+        for config_cls in [BanditLinUCBConfig, BanditLinTSConfig]:
+            config = (
+                config_cls()
+                .debugging(seed=0)
+                .environment(
+                    env=NonContextualBanditEnv,
+                    env_config={"best_arm_prob": best_arm_prob},
+                )
+            )
+            for _ in framework_iterator(
+                config, frameworks=("tf2", "torch"), with_eager_tracing=True
+            ):
+                algo = config.build()
+                model = algo.get_policy().model
+                arm_means, arm_stds = [], []
+                for _ in range(50):
+                    # TODO: The internals of the model are leaking here.
+                    # We should revisit this once the RLModule is merged in.
+                    samples = [model.arms[i].dist.sample((1000,)) for i in range(2)]
+                    arm_means.append(
+                        [float(convert_to_numpy(s).mean(0)) for s in samples]
+                    )
+                    arm_stds.append(
+                        [float(convert_to_numpy(s).std(0)) for s in samples]
+                    )
+                    algo.train()
+
+                best_arm = np.argmax(arm_means[-1])
+                print(
+                    f"best arm: {best_arm}, arm means: {arm_means[-1]}, "
+                    f"arm stds: {arm_stds[-1]}"
+                )
+
+                # The better arm (according to the learned model) should be
+                # sufficiently exploited, so its variance is low at convergence.
+                self.assertLess(arm_stds[-1][best_arm], std_threshold)
+
+                # The best arm should also have a good estimate of its actual
+                # mean. Note that this may not be true for non-optimal arms,
+                # as they may not have been explored enough.
+                check(arm_means[-1][best_arm], best_arm_prob, decimals=1)
+

 if __name__ == "__main__":
     import pytest
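For intuition about the new test: `NonContextualBanditEnv` is a two-armed Bernoulli bandit, so arm 1 pays out `best_arm_prob` on average and arm 0 about 0.1, and a converged policy's best-arm posterior should concentrate near `best_arm_prob`. A hypothetical sanity check of the environment (not part of the PR):

import numpy as np
from ray.rllib.env import EnvContext

# NonContextualBanditEnv as defined in the test file above.
env = NonContextualBanditEnv(EnvContext({"best_arm_prob": 0.5}, worker_index=0))
means = {}
for arm in (0, 1):
    rewards = []
    for _ in range(10000):
        env.reset()
        _, reward, _, _ = env.step(arm)
        rewards.append(reward)
    means[arm] = np.mean(rewards)
print(means)  # roughly {0: 0.1, 1: 0.5}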