Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[RLlib] Enable Bandits to work in batches mode(s) (vector env… #22497

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ py_test(
py_test(
name = "test_bandits",
tags = ["team:ml", "trainers_dir"],
size = "medium",
size = "small",
srcs = ["agents/bandit/tests/test_bandits.py"],
)

Expand Down
87 changes: 46 additions & 41 deletions rllib/agents/bandit/bandit_torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _init_params(self):
)

def partial_fit(self, x, y):
# TODO: Handle batch of data rather than individual points
x, y = self._check_inputs(x, y)
x = x.squeeze(0)
y = y.item()
Expand All @@ -61,31 +62,19 @@ def sample_theta(self):
theta = self.dist.sample()
return theta

def get_ucbs(self, x: torch.Tensor):
def get_ucbs(self, x):
"""Calculate upper confidence bounds using covariance matrix according
to algorithm 1: LinUCB
(http://proceedings.mlr.press/v15/chu11a/chu11a.pdf).

Args:
x: Input feature tensor of shape
(batch_size, [num_items]?, feature_dim)
x (torch.Tensor): Input feature tensor of shape
(batch_size, feature_dim)
"""
# Fold batch and num-items dimensions into one dim.
if len(x.shape) == 3:
B, C, F = x.shape
x_folded_batch = x.reshape([-1, F])
# Only batch and feature dims.
else:
x_folded_batch = x

projections = self.covariance @ x_folded_batch.T
batch_dots = (x_folded_batch * projections.T).sum(dim=-1)
batch_dots = batch_dots.sqrt()

# Restore original B and C dimensions.
if len(x.shape) == 3:
batch_dots = batch_dots.reshape([B, C])
return batch_dots
projections = self.covariance @ x.T
batch_dots = (x * projections.T).sum(dim=1)
return batch_dots.sqrt()

def forward(self, x, sample_theta=False):
"""Predict scores on input batch using the underlying linear model.
Expand All @@ -104,15 +93,19 @@ def forward(self, x, sample_theta=False):

def _check_inputs(self, x, y=None):
assert x.ndim in [2, 3], (
"Input context tensor must be 2 (no batch) or 3 dimensional (where the"
" first dimension is the batch size)."
"Input context tensor must be 2 or 3 dimensional, where the"
" first dimension is batch size"
)
assert x.shape[-1] == self.d, (
assert x.shape[1] == self.d, (
"Feature dimensions of weights ({}) and context ({}) do not "
"match!".format(self.d, x.shape[-1])
"match!".format(self.d, x.shape[1])
)
if y is not None:
assert torch.is_tensor(y), f"ERROR: Target should be a tensor, but is {y}!"
if y:
assert torch.is_tensor(y) and y.numel() == 1, (
"Target should be a tensor;"
"Only online learning with a batch size of 1 is "
"supported for now!"
)
return x if y is None else (x, y)


Expand Down Expand Up @@ -157,14 +150,13 @@ def predict(self, x, sample_theta=False, use_ucb=False):
else:
return scores

def partial_fit(self, x, y, arms):
for i, arm in enumerate(arms):
assert (
0 <= arm.item() < len(self.arms)
), "Invalid arm: {}. It should be 0 <= arm < {}".format(
arm.item(), len(self.arms)
)
self.arms[arm].partial_fit(x[[i]], y[[i]])
def partial_fit(self, x, y, arm):
assert (
0 <= arm.item() < len(self.arms)
), "Invalid arm: {}. It should be 0 <= arm < {}".format(
arm.item(), len(self.arms)
)
self.arms[arm].partial_fit(x, y)

@override(ModelV2)
def value_function(self):
Expand Down Expand Up @@ -219,16 +211,22 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name):
self._cur_ctx = None

def _check_inputs(self, x):
assert (
x.ndim == 3
), f"ERROR: Inputs ({x}) must have 3 dimensions (B x num-items x features)."
if x.ndim == 3 and x.size()[0] != 1:
# Just a test batch, slice to index 0.
if torch.all(x == 0.0):
x = x[0:1]
# An actual batch -> Error.
else:
raise ValueError("Only batch size of 1 is supported for now.")
return x

@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
x = input_dict["obs"]["item"]
x = self._check_inputs(x)
x.squeeze_(dim=0) # Remove the batch dimension
scores = self.predict(x)
scores.unsqueeze_(dim=0) # Add the batch dimension
return scores, state

def predict(self, x, sample_theta=False, use_ucb=False):
Expand All @@ -241,11 +239,10 @@ def predict(self, x, sample_theta=False, use_ucb=False):
else:
return scores

def partial_fit(self, x, y, arms):
def partial_fit(self, x, y, arm):
x = x["item"]
for i, arm in enumerate(arms):
action_id = arm.item()
self.arm.partial_fit(x[[i], action_id], y[[i]])
action_id = arm.item()
self.arm.partial_fit(x[:, action_id], y)

@override(ModelV2)
def value_function(self):
Expand All @@ -261,13 +258,21 @@ class ParametricLinearModelUCB(ParametricLinearModel):
def forward(self, input_dict, state, seq_lens):
x = input_dict["obs"]["item"]
x = self._check_inputs(x)
scores = super().predict(x, sample_theta=False, use_ucb=True)
x.squeeze_(dim=0) # Remove the batch dimension
scores = super(ParametricLinearModelUCB, self).predict(
x, sample_theta=False, use_ucb=True
)
scores.unsqueeze_(dim=0) # Add the batch dimension
return scores, state


class ParametricLinearModelThompsonSampling(ParametricLinearModel):
def forward(self, input_dict, state, seq_lens):
x = input_dict["obs"]["item"]
x = self._check_inputs(x)
scores = super().predict(x, sample_theta=True, use_ucb=False)
x.squeeze_(dim=0) # Remove the batch dimension
scores = super(ParametricLinearModelThompsonSampling, self).predict(
x, sample_theta=True, use_ucb=False
)
scores.unsqueeze_(dim=0) # Add the batch dimension
return scores, state
45 changes: 18 additions & 27 deletions rllib/agents/bandit/tests/test_bandits.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,40 @@ def tearDownClass(cls) -> None:
def test_bandit_lin_ts_compilation(self):
"""Test whether a BanditLinTSTrainer can be built on all frameworks."""
config = {
# Use a simple bandit-friendly env.
# Use a simple bandit friendly env.
"env": SimpleContextualBandit,
"num_envs_per_worker": 2, # Test batched inference.
"num_workers": 2, # Test distributed bandits.
}

num_iterations = 5

for _ in framework_iterator(config, frameworks="torch"):
for train_batch_size in [1, 10]:
config["train_batch_size"] = train_batch_size
trainer = bandit.BanditLinTSTrainer(config=config)
results = None
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
print(results)
# Force good learning behavior (this is a very simple env).
self.assertTrue(results["episode_reward_mean"] == 10.0)
trainer.stop()
trainer = bandit.BanditLinTSTrainer(config=config)
results = None
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
print(results)
# Force good learning behavior (this is a very simple env).
self.assertTrue(results["episode_reward_mean"] == 10.0)

def test_bandit_lin_ucb_compilation(self):
"""Test whether a BanditLinUCBTrainer can be built on all frameworks."""
config = {
# Use a simple bandit-friendly env.
# Use a simple bandit friendly env.
"env": SimpleContextualBandit,
"num_envs_per_worker": 2, # Test batched inference.
}

num_iterations = 5

for _ in framework_iterator(config, frameworks="torch"):
for train_batch_size in [1, 10]:
config["train_batch_size"] = train_batch_size
trainer = bandit.BanditLinUCBTrainer(config=config)
results = None
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
print(results)
# Force good learning behavior (this is a very simple env).
self.assertTrue(results["episode_reward_mean"] == 10.0)
trainer.stop()
trainer = bandit.BanditLinUCBTrainer(config=config)
results = None
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
print(results)
# Force good learning behavior (this is a very simple env).
self.assertTrue(results["episode_reward_mean"] == 10.0)

def test_deprecated_locations(self):
"""Tests, whether importing from old contrib dir fails gracefully.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,15 @@
import pandas as pd
import time

import ray
from ray import tune
from ray.rllib.examples.env.bandit_envs_recommender_system import ParametricItemRecoEnv

if __name__ == "__main__":
# Temp fix to avoid OMP conflict.
# Temp fix to avoid OMP conflict
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

ray.init()

config = {
"env": ParametricItemRecoEnv,
"num_envs_per_worker": 2, # Test with batched inference.
}

# Actual training_iterations will be 10 * timesteps_per_iteration
Expand Down