diff --git a/lambo/acquisitions/monte_carlo.py b/lambo/acquisitions/monte_carlo.py index f8e84a5..e334d34 100644 --- a/lambo/acquisitions/monte_carlo.py +++ b/lambo/acquisitions/monte_carlo.py @@ -71,18 +71,19 @@ def forward(self, X: array) -> Tensor: baseline_X = self._X_baseline baseline_X = baseline_X.expand(*X.shape[:-2], -1, -1) X_full = torch.cat([baseline_X, X], dim=-2) + q = X.shape[-2] else: baseline_X = copy(self.X_baseline_string) # ensure contiguity baseline_X.resize( baseline_X.shape[:-(X.ndim)] + X.shape[:-1] + baseline_X.shape[-1:] ) X_full = concatenate([baseline_X, X], axis=-1) + q = X.shape[-1] # Note: it is important to compute the full posterior over `(X_baseline, X)`` # to ensure that we properly sample `f(X)` from the joint distribution ` # `f(X_baseline, X) ~ P(f | D)` given that we can already fixed the sampled # function values for `f(X_baseline)` posterior = self.model.posterior(X_full) - q = X.shape[-2] self._set_sampler(q=q, posterior=posterior) samples = self.sampler(posterior)[..., -q:, :] # add previous nehvi from pending points diff --git a/lambo/candidate.py b/lambo/candidate.py index 7b632d3..411de03 100644 --- a/lambo/candidate.py +++ b/lambo/candidate.py @@ -8,13 +8,6 @@ from lambo.utils import StringSubstitution, StringDeletion, StringInsertion, FoldxMutation - -def apply_mutation(base_seq, mut_pos, mut_res, tokenizer): - tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")[1:-1] - mut_seq = "".join(tokens[:mut_pos] + [mut_res] + tokens[(mut_pos + 1):]) - return mut_seq - - def pdb_to_residues(pdb_path, chain_id='A'): """ :param pdb_path: path to pdb file (str or Path) diff --git a/lambo/models/base_surrogate.py b/lambo/models/base_surrogate.py index 994c2f9..ced5993 100644 --- a/lambo/models/base_surrogate.py +++ b/lambo/models/base_surrogate.py @@ -10,7 +10,7 @@ class BaseSurrogate(torch.nn.Module): device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') - dtype = torch.float + dtype = torch.double def _set_transforms(self, tokenizer, max_shift, mask_size, train_prepend=None): # convert from string to LongTensor of token indexes diff --git a/lambo/optimizers/lambo.py b/lambo/optimizers/lambo.py index f203f90..af9fae7 100644 --- a/lambo/optimizers/lambo.py +++ b/lambo/optimizers/lambo.py @@ -310,9 +310,7 @@ def optimize(self, candidate_pool, pool_targets, all_seqs, all_targets, log_pref else: raise ValueError - # import pdb; pdb.set_trace() - lat_acq_vals = acq_fn(pooled_features.unsqueeze(-2)) - # lat_acq_vals = acq_fn(pooled_features.unsqueeze(0)) + lat_acq_vals = acq_fn(pooled_features.unsqueeze(0)) loss = -lat_acq_vals.mean() + self.entropy_penalty * logit_entropy.mean() if self.optimize_latent: @@ -322,31 +320,20 @@ def optimize(self, candidate_pool, pool_targets, all_seqs, all_targets, log_pref tgt_seqs = tokens_to_str(tgt_tok_idxs, self.encoder.tokenizer) with torch.no_grad(): - act_acq_vals = acq_fn(tgt_seqs[..., None]) - # act_acq_vals = acq_fn(tgt_seqs[None, :]).mean().item() - - is_improved = (act_acq_vals >= best_scores) - best_scores = torch.where(is_improved, act_acq_vals, best_scores) - best_seqs = np.where(is_improved.cpu().numpy(), tgt_seqs, best_seqs) - # best_scores[is_improved] = act_acq_vals[is_improved] - # best_seqs[is_improved] = tgt_seqs[is_improved] - - with torch.no_grad(): - batch_acq_val = acq_fn(best_seqs[None, :]).mean().item() - curr_score = -1.0 * batch_acq_val + act_acq_vals = acq_fn(tgt_seqs[None, :]).mean().item() best_score, best_step, _, stop = check_early_stopping( model=None, best_score=best_score, best_epoch=best_step, best_weights=None, - curr_score=curr_score, + curr_score=-act_acq_vals, curr_epoch=step_idx + 1, patience=self.patience, save_weights=False, ) if (step_idx + 1) == best_step: - # best_seqs = tgt_seqs.copy() + best_seqs = tgt_seqs.copy() best_entropy = logit_entropy.mean().item() if stop: print(f"Early stopping at step {step_idx + 1}") diff --git a/lambo/tasks/surrogate_task.py b/lambo/tasks/surrogate_task.py index b6d0b15..7de4f78 100644 --- a/lambo/tasks/surrogate_task.py +++ b/lambo/tasks/surrogate_task.py @@ -19,7 +19,6 @@ def _evaluate(self, x, out, *args, **kwargs): cand_idx, mut_pos, mut_res_idx, op_idx = query_pt op_type = self.op_types[op_idx] base_seq = self.candidate_pool[cand_idx].mutant_residue_seq - mut_pos = mut_pos % len(base_seq) mut_res = self.tokenizer.sampling_vocab[mut_res_idx] mutant_seq = apply_mutation(base_seq, mut_pos, mut_res, op_type, self.tokenizer) candidates.append(mutant_seq) diff --git a/lambo/utils.py b/lambo/utils.py index 7601888..a677f62 100644 --- a/lambo/utils.py +++ b/lambo/utils.py @@ -432,7 +432,8 @@ def tokens_to_str(tok_idx_array, tokenizer): def apply_mutation(base_seq, mut_pos, mut_res, op_type, tokenizer): - tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")[1:-1] + tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ") + mut_pos = mut_pos % len(tokens) if op_type == 'sub': mut_seq = "".join(tokens[:mut_pos] + [mut_res] + tokens[(mut_pos + 1):])