
Issue with Acquisition Function Calculation in LAMBO Implementation #14

Closed
badeok0716 opened this issue Jan 4, 2024 · 2 comments
badeok0716 commented Jan 4, 2024

I'm currently working on biological sequence design experiments based on the LAMBO paper and am considering LAMBO as one of the main baselines.
I appreciate the authors providing well-reproducible experiments; I could easily reproduce Figure 3 from this repository.

However, the implementation in this repo seems to have an issue in how it calculates the NoisyEHVI acquisition function.

My concern is about lambo/optimizers/lambo.py Line 335:

batch_acq_val = acq_fn(best_seqs[None, :]).mean().item()

where best_seqs[None, :] is a numpy array of strings with shape [1, batch_size].

The call to acq_fn invokes the forward method defined at https://github.com/samuelstanton/lambo/blob/main/lambo/acquisitions/monte_carlo.py#L69-L89:

def forward(self, X: array) -> Tensor:
    if isinstance(X, Tensor):
        baseline_X = self._X_baseline
        baseline_X = baseline_X.expand(*X.shape[:-2], -1, -1)
        X_full = torch.cat([baseline_X, X], dim=-2)
    else:
        baseline_X = copy(self.X_baseline_string) # ensure contiguity
        baseline_X.resize(
            baseline_X.shape[:-(X.ndim)] + X.shape[:-1] + baseline_X.shape[-1:]
        )
        X_full = concatenate([baseline_X, X], axis=-1)
    # Note: it is important to compute the full posterior over `(X_baseline, X)`
    # to ensure that we properly sample `f(X)` from the joint distribution
    # `f(X_baseline, X) ~ P(f | D)` given that we can already fixed the sampled
    # function values for `f(X_baseline)`
    posterior = self.model.posterior(X_full)
    q = X.shape[-2]
    self._set_sampler(q=q, posterior=posterior)
    samples = self.sampler(posterior)[..., -q:, :]
    # add previous nehvi from pending points
    return self._compute_qehvi(samples=samples) + self._prev_nehvi

Note that:
X : numpy array of strings with shape [1, batch_size]
self.X_baseline_string : numpy array of sequences with shape [n_base]
X_full : numpy array of strings with shape [1, batch_size + n_base]
Hence,
q = X.shape[-2] equals 1, and

        self._set_sampler(q=q, posterior=posterior)
        samples = self.sampler(posterior)[..., -q:, :] 
        # add previous nehvi from pending points
        return self._compute_qehvi(samples=samples) + self._prev_nehvi

this part only calculates the 1-NEHVI of the last sequence in X (i.e., X[0, -1]). In other words, the other sequences in X (X[0, :-1]) are ignored in the calculation.
(Note that samples is a torch Tensor of shape [n_samples, 1, q, fdim].)
This problem occurs whenever X is a numpy array of sequences with X.ndim == 2 and X.shape[-1] > 1.
I think the problem originates from q = X.shape[-2], which assumes that X holds features extracted from the sequences rather than the raw sequence strings. The sketch below makes the shape issue concrete.
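
A minimal, self-contained sketch with made-up sizes (batch_size = 16, n_base = 512, n_samples = 8, fdim = 3 are illustrative, not the repo's values); a random tensor stands in for the posterior samples, and only the indexing behaviour of the quoted code is reproduced:

import numpy as np
import torch

batch_size, n_base, n_samples, fdim = 16, 512, 8, 3

# X as passed in from lambo.py Line 335: a 2-D numpy array of strings, shape (1, batch_size)
X = np.array([[f"seq_{i}" for i in range(batch_size)]])
q = X.shape[-2]   # == 1: axis -2 indexes rows of the string array, not candidate sequences
assert q == 1

# stand-in for the MC posterior samples over (X_baseline, X),
# shape (n_samples, 1, n_base + batch_size, fdim)
samples_full = torch.randn(n_samples, 1, n_base + batch_size, fdim)
samples = samples_full[..., -q:, :]
print(samples.shape)   # torch.Size([8, 1, 1, 3]) -> only the last candidate sequence is scored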

This causes two problems:

  1. [LAMBO] Incorrect logic for determining best_batch_idx (https://github.com/samuelstanton/lambo/blob/main/lambo/optimizers/lambo.py#L355-L365).
  2. [MBGA] Incorrect NEHVI calculation in SurrogateTask (https://github.com/samuelstanton/lambo/blob/main/lambo/tasks/surrogate_task.py#L28).

There are two possible fixes, depending on the authors' original intention (a shape-only sketch contrasting them follows the list).

  1. If the authors intended acq_fn(X) to compute the joint N-NEHVI of X[0, :] (X: numpy array of shape [1, N]):

def forward(self, X: array, debug: bool = True) -> Tensor:
    if isinstance(X, Tensor):
        baseline_X = self._X_baseline
        baseline_X = baseline_X.expand(*X.shape[:-2], -1, -1)
        X_full = torch.cat([baseline_X, X], dim=-2)
    elif X.ndim == 2:
        # X : (1, N) numpy array of sequence strings
        # compute the joint N-NEHVI of all N candidates at once
        assert X.shape[0] == 1, f"X type {type(X)}, X ndim {X.ndim}, X shape {X.shape}"
        X = self.model.get_features(X[0], self.model.bs)        # X : (N, 16)
        X = X.unsqueeze(0)                                      # X : (1, N, 16)
        baseline_X = self._X_baseline                           # baseline_X : (1, n, 16)
        baseline_X = baseline_X.expand(*X.shape[:-2], -1, -1)   # baseline_X : (1, n, 16)
        X_full = torch.cat([baseline_X, X], dim=-2)             # X_full : (1, n+N, 16)
  2. If the authors intended acq_fn(X) to compute the average 1-NEHVI over the individual X[0, i] (X: numpy array of shape [1, N]):

def forward(self, X: array, debug: bool = True) -> Tensor:
    if isinstance(X, Tensor):
        baseline_X = self._X_baseline
        baseline_X = baseline_X.expand(*X.shape[:-2], -1, -1)
        X_full = torch.cat([baseline_X, X], dim=-2)
    elif X.ndim == 2:
        # X : (1, N) numpy array of sequence strings
        # compute the 1-NEHVI of each of the N candidates in parallel
        assert X.shape[0] == 1, f"X type {type(X)}, X ndim {X.ndim}, X shape {X.shape}"
        X = self.model.get_features(X[0], self.model.bs)        # X : (N, 16)
        X = X.unsqueeze(-2)                                     # X : (N, 1, 16)
        baseline_X = self._X_baseline                           # baseline_X : (1, n, 16)
        baseline_X = baseline_X.expand(*X.shape[:-2], -1, -1)   # baseline_X : (N, n, 16)
        X_full = torch.cat([baseline_X, X], dim=-2)             # X_full : (N, n+1, 16)
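
To make the difference between the two conventions concrete, here is a shape-only sketch; the feature dimension 16 and the sizes n and N are illustrative, and random tensors stand in for self._X_baseline and the output of self.model.get_features:

import torch

n, N, d = 512, 16, 16              # baseline size, candidate count, feature dim (illustrative)
feats = torch.randn(N, d)          # stand-in for self.model.get_features(X[0], self.model.bs)
baseline = torch.randn(1, n, d)    # stand-in for self._X_baseline

# Fix 1: one joint q-batch -> q = N, a single joint N-NEHVI value
X1 = feats.unsqueeze(0)                                                   # (1, N, d)
full1 = torch.cat([baseline.expand(*X1.shape[:-2], -1, -1), X1], dim=-2)
print(full1.shape, X1.shape[-2])                                          # (1, n+N, d), q == N

# Fix 2: N parallel 1-point batches -> q = 1, N separate 1-NEHVI values
X2 = feats.unsqueeze(-2)                                                  # (N, 1, d)
full2 = torch.cat([baseline.expand(*X2.shape[:-2], -1, -1), X2], dim=-2)
print(full2.shape, X2.shape[-2])                                          # (N, n+1, d), q == 1

With fix 2, q = X.shape[-2] evaluates to 1 with a batch dimension of N, so (assuming _compute_qehvi returns one value per batch element) acq_fn returns N per-sequence values and the .mean() at lambo.py Line 335 averages them.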

Since I couldn't find any mention of a parallel acquisition function (q-NEHVI with q > 1) in the LAMBO paper, I implemented version (2) and obtained the following results with the same commands (10 seeds).

[Screenshot (2024-01-04): results over 10 seeds comparing the original and fixed versions]

(I'll update the figure after I finish the LaMBO-Fixed and MBGA-Fixed experiments on the other tasks. Some trials have not finished yet, but the fixed versions appear to perform better on the other tasks as well.)

Could you please verify whether my concern is valid? I'd also like to know whether the corrected implementation aligns with your original intent.
Please refer to pull request #13.

badeok0716 commented Jan 11, 2024

I summarized three issues (this q-NEHVI calculation issue and two additional issues) in this repo and illustrated the performance after the correction here.

samuelstanton (Owner) commented

Closing the issue as the fix has been merged. Thanks again!
