Skip to content

Commit

Permalink
Merge pull request #13 from badeok0716/main
Browse files Browse the repository at this point in the history
Correcting Acquisition Function Calculation
  • Loading branch information
samuelstanton authored Apr 20, 2024
2 parents 04117b3 + 4977889 commit 6d672b3
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 28 deletions.
3 changes: 2 additions & 1 deletion lambo/acquisitions/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,19 @@ def forward(self, X: array) -> Tensor:
baseline_X = self._X_baseline
baseline_X = baseline_X.expand(*X.shape[:-2], -1, -1)
X_full = torch.cat([baseline_X, X], dim=-2)
q = X.shape[-2]
else:
baseline_X = copy(self.X_baseline_string) # ensure contiguity
baseline_X.resize(
baseline_X.shape[:-(X.ndim)] + X.shape[:-1] + baseline_X.shape[-1:]
)
X_full = concatenate([baseline_X, X], axis=-1)
q = X.shape[-1]
# Note: it is important to compute the full posterior over `(X_baseline, X)``
# to ensure that we properly sample `f(X)` from the joint distribution `
# `f(X_baseline, X) ~ P(f | D)` given that we can already fixed the sampled
# function values for `f(X_baseline)`
posterior = self.model.posterior(X_full)
q = X.shape[-2]
self._set_sampler(q=q, posterior=posterior)
samples = self.sampler(posterior)[..., -q:, :]
# add previous nehvi from pending points
Expand Down
7 changes: 0 additions & 7 deletions lambo/candidate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@

from lambo.utils import StringSubstitution, StringDeletion, StringInsertion, FoldxMutation


def apply_mutation(base_seq, mut_pos, mut_res, tokenizer):
tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")[1:-1]
mut_seq = "".join(tokens[:mut_pos] + [mut_res] + tokens[(mut_pos + 1):])
return mut_seq


def pdb_to_residues(pdb_path, chain_id='A'):
"""
:param pdb_path: path to pdb file (str or Path)
Expand Down
2 changes: 1 addition & 1 deletion lambo/models/base_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class BaseSurrogate(torch.nn.Module):
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float
dtype = torch.double

def _set_transforms(self, tokenizer, max_shift, mask_size, train_prepend=None):
# convert from string to LongTensor of token indexes
Expand Down
21 changes: 4 additions & 17 deletions lambo/optimizers/lambo.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,7 @@ def optimize(self, candidate_pool, pool_targets, all_seqs, all_targets, log_pref
else:
raise ValueError

# import pdb; pdb.set_trace()
lat_acq_vals = acq_fn(pooled_features.unsqueeze(-2))
# lat_acq_vals = acq_fn(pooled_features.unsqueeze(0))
lat_acq_vals = acq_fn(pooled_features.unsqueeze(0))
loss = -lat_acq_vals.mean() + self.entropy_penalty * logit_entropy.mean()

if self.optimize_latent:
Expand All @@ -322,31 +320,20 @@ def optimize(self, candidate_pool, pool_targets, all_seqs, all_targets, log_pref

tgt_seqs = tokens_to_str(tgt_tok_idxs, self.encoder.tokenizer)
with torch.no_grad():
act_acq_vals = acq_fn(tgt_seqs[..., None])
# act_acq_vals = acq_fn(tgt_seqs[None, :]).mean().item()

is_improved = (act_acq_vals >= best_scores)
best_scores = torch.where(is_improved, act_acq_vals, best_scores)
best_seqs = np.where(is_improved.cpu().numpy(), tgt_seqs, best_seqs)
# best_scores[is_improved] = act_acq_vals[is_improved]
# best_seqs[is_improved] = tgt_seqs[is_improved]

with torch.no_grad():
batch_acq_val = acq_fn(best_seqs[None, :]).mean().item()
curr_score = -1.0 * batch_acq_val
act_acq_vals = acq_fn(tgt_seqs[None, :]).mean().item()

best_score, best_step, _, stop = check_early_stopping(
model=None,
best_score=best_score,
best_epoch=best_step,
best_weights=None,
curr_score=curr_score,
curr_score=-act_acq_vals,
curr_epoch=step_idx + 1,
patience=self.patience,
save_weights=False,
)
if (step_idx + 1) == best_step:
# best_seqs = tgt_seqs.copy()
best_seqs = tgt_seqs.copy()
best_entropy = logit_entropy.mean().item()
if stop:
print(f"Early stopping at step {step_idx + 1}")
Expand Down
1 change: 0 additions & 1 deletion lambo/tasks/surrogate_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def _evaluate(self, x, out, *args, **kwargs):
cand_idx, mut_pos, mut_res_idx, op_idx = query_pt
op_type = self.op_types[op_idx]
base_seq = self.candidate_pool[cand_idx].mutant_residue_seq
mut_pos = mut_pos % len(base_seq)
mut_res = self.tokenizer.sampling_vocab[mut_res_idx]
mutant_seq = apply_mutation(base_seq, mut_pos, mut_res, op_type, self.tokenizer)
candidates.append(mutant_seq)
Expand Down
3 changes: 2 additions & 1 deletion lambo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ def tokens_to_str(tok_idx_array, tokenizer):


def apply_mutation(base_seq, mut_pos, mut_res, op_type, tokenizer):
tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")[1:-1]
tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")
mut_pos = mut_pos % len(tokens)

if op_type == 'sub':
mut_seq = "".join(tokens[:mut_pos] + [mut_res] + tokens[(mut_pos + 1):])
Expand Down

0 comments on commit 6d672b3

Please sign in to comment.