Typehinting #180

Merged
merged 12 commits on Jun 15, 2024
101 changes: 58 additions & 43 deletions examples/bayesopt_example.py
@@ -1,4 +1,5 @@
import warnings

warnings.filterwarnings('ignore')

import numpy as np
@@ -10,7 +11,7 @@
from gpytorch import distributions as gdists
import torch.utils.data as data_utils
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
import tqdm

from laplace import Laplace
@@ -45,13 +46,13 @@ class LaplaceBNN(Model):
"""

def __init__(
self,
train_X: torch.Tensor,
train_Y: torch.Tensor,
bnn: Laplace = None,
likelihood: str = 'regression',
batch_size: int = 1024):

self,
train_X: torch.Tensor,
train_Y: torch.Tensor,
bnn: Laplace = None,
likelihood: str = 'regression',
batch_size: int = 1024,
):
super().__init__()

self.train_X = train_X
@@ -63,14 +64,13 @@ def __init__(
nn.ReLU(),
nn.Linear(50, 50),
nn.ReLU(),
nn.Linear(50, train_Y.shape[-1])
nn.Linear(50, train_Y.shape[-1]),
)
self.bnn = bnn

if self.bnn is None:
self._train_model(self._get_train_loader())


def posterior(
self,
X: torch.Tensor,
@@ -84,21 +84,21 @@ def posterior(
"""
# Transform to `(batch_shape*q, d)`
B, Q, D = X.shape
X = X.reshape(B*Q, D)
X = X.reshape(B * Q, D)

# Posterior predictive distribution
# mean_y is (batch_shape*q, k); cov_y is (batch_shape*q*k, batch_shape*q*k)
mean_y, cov_y = self._get_prediction(X, use_test_loader=False)

# Mean in `(batch_shape, q*k)`
K = self.num_outputs
mean_y = mean_y.reshape(B, Q*K)
mean_y = mean_y.reshape(B, Q * K)

# Cov is `(batch_shape, q*k, q*k)`
cov_y += 1e-4*torch.eye(B*Q*K)
cov_y += 1e-4 * torch.eye(B * Q * K)
cov_y = cov_y.reshape(B, Q, K, B, Q, K)
cov_y = torch.einsum('bqkbrl->bqkrl', cov_y) # (B, Q, K, Q, K)
cov_y = cov_y.reshape(B, Q*K, Q*K)
cov_y = cov_y.reshape(B, Q * K, Q * K)

dist = gdists.MultivariateNormal(mean_y, covariance_matrix=cov_y)
post_pred = GPyTorchPosterior(dist)
@@ -108,8 +108,9 @@ def posterior(

return post_pred


def condition_on_observations(self, X: torch.Tensor, Y: torch.Tensor, **kwargs: Any) -> Model:
def condition_on_observations(
self, X: torch.Tensor, Y: torch.Tensor, **kwargs: Any
) -> Model:
self.train_X = torch.cat([self.train_X, X], dim=0)
self.train_Y = torch.cat([self.train_Y, Y], dim=0)

@@ -118,8 +119,11 @@ def condition_on_observations(self, X: torch.Tensor, Y: torch.Tensor, **kwargs: Any) -> Model:

return LaplaceBNN(
# Added dataset & retrained BNN
self.train_X, self.train_Y, self.bnn,
self.likelihood, self.batch_size,
self.train_X,
self.train_Y,
self.bnn,
self.likelihood,
self.batch_size,
)

@property
@@ -133,7 +137,9 @@ def _train_model(self, train_loader):
"""
n_epochs = 1000
optimizer = optim.Adam(self.nn.parameters(), lr=1e-1, weight_decay=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs*len(train_loader))
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, n_epochs * len(train_loader)
)
loss_func = nn.MSELoss()

for i in range(n_epochs):
@@ -148,14 +154,15 @@ def _train_model(self, train_loader):
self.nn.eval()

self.bnn = Laplace(
self.nn, self.likelihood,
subset_of_weights='all', hessian_structure='kron',
enable_backprop=True
self.nn,
self.likelihood,
subset_of_weights='all',
hessian_structure='kron',
enable_backprop=True,
)
self.bnn.fit(train_loader)
self.bnn.optimize_prior_precision(n_steps=50)


def _get_prediction(self, test_x: torch.Tensor, joint=True, use_test_loader=False):
"""
Batched Laplace prediction.
@@ -175,7 +182,7 @@ def _get_prediction(self, test_x: torch.Tensor, joint=True, use_test_loader=False):
else:
test_loader = data_utils.DataLoader(
data_utils.TensorDataset(test_x, torch.zeros_like(test_x)),
batch_size=256
batch_size=256,
)

mean_y, cov_y = [], []
@@ -190,13 +197,14 @@ def _get_prediction(self, test_x: torch.Tensor, joint=True, use_test_loader=False):

return mean_y, cov_y


def _get_train_loader(self):
return data_utils.DataLoader(
data_utils.TensorDataset(self.train_X, self.train_Y),
batch_size=self.batch_size, shuffle=True
batch_size=self.batch_size,
shuffle=True,
)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--test_func', choices=['ackley', 'branin'], default='branin')
@@ -227,25 +235,30 @@ def _get_train_loader(self):

train_data_points = 20

train_x = torch.cat([
dists.Uniform(*bounds.T[i]).sample((train_data_points, 1))
for i in range(2) # for each dimension
], dim=1)
train_x = torch.cat(
[
dists.Uniform(*bounds.T[i]).sample((train_data_points, 1))
for i in range(2) # for each dimension
],
dim=1,
)
train_y = true_f(train_x).reshape(-1, 1)

test_x = torch.cat([
dists.Uniform(*bounds.T[i]).sample((10000, 1))
for i in range(2) # for each dimension
], dim=1)
test_x = torch.cat(
[
dists.Uniform(*bounds.T[i]).sample((10000, 1))
for i in range(2) # for each dimension
],
dim=1,
)
test_y = true_f(test_x)

models = {
'RandomSearch': None,
'BNN-LA': LaplaceBNN(train_x, train_y),
'GP': SingleTaskGP(train_x, train_y),
}


def evaluate_model(model_name, model):
if model_name == 'GP':
pred = model.posterior(test_x).mean.squeeze()
@@ -256,7 +269,6 @@ def evaluate_model(model_name, model):

return F.mse_loss(pred, test_y).squeeze().item()


for model_name, model in models.items():
np.random.seed(args.randseed)
torch.set_default_dtype(torch.float64)
@@ -271,10 +283,13 @@ def evaluate_model(model_name, model):

for i in pbar:
if model_name == 'RandomSearch':
new_x = torch.cat([
dists.Uniform(*bounds.T[i]).sample((1, 1))
for i in range(len(bounds)) # for each dimension
], dim=1).squeeze()
new_x = torch.cat(
[
dists.Uniform(*bounds.T[i]).sample((1, 1))
for i in range(len(bounds)) # for each dimension
],
dim=1,
).squeeze()
else:
if args.acqf == 'EI':
acq_f = ExpectedImprovement(model, best_f=best_y, maximize=False)
@@ -292,12 +307,12 @@
bounds=bounds,
q=1 if args.acqf not in ['qEI'] else 5,
num_restarts=10,
raw_samples=20
raw_samples=20,
)

if len(new_x.shape) == 1:
new_x = new_x.unsqueeze(0)

# Evaluate the objective on the proposed x
new_y = true_f(new_x).unsqueeze(-1) # (q, 1)

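For readers skimming this diff: a minimal usage sketch of the `LaplaceBNN` defined above. The tensor shapes and values are illustrative assumptions, not part of the example script.

import torch

train_X = torch.rand(20, 2)  # 20 training points in 2-D, as in the script
train_Y = torch.rand(20, 1)  # single regression output

model = LaplaceBNN(train_X, train_Y)  # trains the MLP and fits the Laplace approx.
X_cand = torch.rand(5, 1, 2)          # posterior() expects shape (batch_shape, q, d)
post = model.posterior(X_cand)        # GPyTorchPosterior over the q*k outputs
print(post.mean.shape)                # batched predictive mean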
13 changes: 8 additions & 5 deletions examples/calibration_example.py
@@ -1,5 +1,6 @@
import warnings
warnings.simplefilter("ignore", UserWarning)

warnings.simplefilter('ignore', UserWarning)

import torch
import torch.distributions as dists
@@ -50,9 +51,9 @@ def predict(dataloader, model, laplace=False):
print(f'[MAP] Acc.: {acc_map:.1%}; ECE: {ece_map:.1%}; NLL: {nll_map:.3}')

# Laplace
la = Laplace(model, 'classification',
subset_of_weights='last_layer',
hessian_structure='kron')
la = Laplace(
model, 'classification', subset_of_weights='last_layer', hessian_structure='kron'
)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

@@ -61,4 +62,6 @@ def predict(dataloader, model, laplace=False):
ece_laplace = ECE(bins=15).measure(probs_laplace.numpy(), targets.numpy())
nll_laplace = -dists.Categorical(probs_laplace).log_prob(targets).mean()

print(f'[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}')
print(
f'[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}'
)
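For context, `predict` is defined earlier in calibration_example.py, outside this diff. A sketch of the pattern it follows, under the assumption that the Laplace predictive `la(x)` already returns probabilities while the MAP model returns logits:

@torch.no_grad()
def predict(dataloader, model, laplace=False):
    probs = []
    for x, _ in dataloader:
        if laplace:
            probs.append(model(x))  # Laplace predictive yields probabilities directly
        else:
            probs.append(torch.softmax(model(x), dim=-1))  # MAP: logits -> probabilities
    return torch.cat(probs)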
24 changes: 13 additions & 11 deletions examples/helper/dataloaders.py
@@ -13,22 +13,26 @@
def CIFAR10(train=True, batch_size=None, augm_flag=True):
if batch_size == None:
if train:
batch_size=train_batch_size
batch_size = train_batch_size
else:
batch_size=test_batch_size
batch_size = test_batch_size

transform_base = [transforms.ToTensor()]
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
] + transform_base)
transform_train = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
]
+ transform_base
)
transform_test = transforms.Compose(transform_base)
transform_train = transforms.RandomChoice([transform_train, transform_test])
transform = transform_train if (augm_flag and train) else transform_test

dataset = datasets.CIFAR10(path, train=train, transform=transform, download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=train, num_workers=4)
loader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=train, num_workers=4
)

return loader

@@ -38,9 +42,7 @@ def get_sinusoid_example(n_data=150, sigma_noise=0.3, batch_size=150):
X_train = (torch.rand(n_data) * 8).unsqueeze(-1)
y_train = torch.sin(X_train) + torch.randn_like(X_train) * sigma_noise
train_loader = data_utils.DataLoader(
data_utils.TensorDataset(X_train, y_train),
batch_size=batch_size
data_utils.TensorDataset(X_train, y_train), batch_size=batch_size
)
X_test = torch.linspace(-5, 13, 500).unsqueeze(-1)
return X_train, y_train, train_loader, X_test
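A quick usage sketch for the two helpers touched in this file; the call pattern is illustrative and uses the defaults visible in the diff:

# CIFAR-10 test loader; with batch_size=None it falls back to test_batch_size
test_loader = CIFAR10(train=False, augm_flag=False)

# Toy 1-D sinusoid regression data plus a DataLoader over it
X_train, y_train, train_loader, X_test = get_sinusoid_example(
    n_data=150, sigma_noise=0.3, batch_size=150
)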

23 changes: 16 additions & 7 deletions examples/helper/util.py
@@ -9,13 +9,16 @@ def download_pretrained_model():
if not os.path.exists('./temp'):
os.makedirs('./temp')

urllib.request.urlretrieve('https://nc.mlcloud.uni-tuebingen.de/index.php/s/2PBDYDsiotN76mq/download', './temp/CIFAR10_plain.pt')
urllib.request.urlretrieve(
'https://nc.mlcloud.uni-tuebingen.de/index.php/s/2PBDYDsiotN76mq/download',
'./temp/CIFAR10_plain.pt',
)


def plot_regression(X_train, y_train, X_test, f_test, y_std, plot=True,
file_name='regression_example'):
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey=True,
figsize=(4.5, 2.8))
def plot_regression(
X_train, y_train, X_test, f_test, y_std, plot=True, file_name='regression_example'
):
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(4.5, 2.8))
ax1.set_title('MAP')
ax1.scatter(X_train.flatten(), y_train.flatten(), alpha=0.3, color='tab:orange')
ax1.plot(X_test, f_test, color='black', label='$f_{MAP}$')
@@ -24,8 +27,14 @@ def plot_regression(X_train, y_train, X_test, f_test, y_std, plot=True,
ax2.set_title('LA')
ax2.scatter(X_train.flatten(), y_train.flatten(), alpha=0.3, color='tab:orange')
ax2.plot(X_test, f_test, label='$\mathbb{E}[f]$')
ax2.fill_between(X_test, f_test-y_std*2, f_test+y_std*2,
alpha=0.3, color='tab:blue', label='$2\sqrt{\mathbb{V}\,[y]}$')
ax2.fill_between(
X_test,
f_test - y_std * 2,
f_test + y_std * 2,
alpha=0.3,
color='tab:blue',
label='$2\sqrt{\mathbb{V}\,[y]}$',
)
ax2.legend()
ax1.set_ylim([-4, 6])
ax1.set_xlim([X_test.min(), X_test.max()])
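Finally, a hypothetical call matching the reformatted `plot_regression` signature (`f_mu` and `y_std` are placeholder names for a predictive mean and standard deviation, e.g. from a fitted Laplace model):

plot_regression(
    X_train, y_train, X_test.flatten(), f_mu, y_std,
    plot=False, file_name='regression_example',
)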