diff --git a/bin/sample.py b/bin/sample.py
index 5d4015c..97cfccf 100644
--- a/bin/sample.py
+++ b/bin/sample.py
@@ -281,6 +281,7 @@ def build_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--testcomparison", action="store_true", help="Run comparison against test set"
     )
+    parser.add_argument("--nopsea", action="store_true", help="Skip PSEA calculations")
     parser.add_argument("--seed", type=int, default=SEED, help="Random seed")
     parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")
     return parser
@@ -453,18 +454,19 @@ def main() -> None:
     )
 
     # Generate plots of secondary structure co-occurrence
-    make_ss_cooccurrence_plot(
-        pdb_files,
-        str(outdir / "plots" / "ss_cooccurrence_sampled.pdf"),
-        threads=multiprocessing.cpu_count(),
-    )
-    if args.testcomparison:
+    if not args.nopsea:
         make_ss_cooccurrence_plot(
-            test_dset.filenames,
-            str(outdir / "plots" / "ss_cooccurrence_test.pdf"),
-            max_seq_len=test_dset.dset.pad,
+            pdb_files,
+            str(outdir / "plots" / "ss_cooccurrence_sampled.pdf"),
             threads=multiprocessing.cpu_count(),
         )
+        if args.testcomparison:
+            make_ss_cooccurrence_plot(
+                test_dset.filenames,
+                str(outdir / "plots" / "ss_cooccurrence_test.pdf"),
+                max_seq_len=test_dset.dset.pad,
+                threads=multiprocessing.cpu_count(),
+            )
 
 
 if __name__ == "__main__":
diff --git a/foldingdiff/sampling.py b/foldingdiff/sampling.py
index a1c3cb0..8726568 100644
--- a/foldingdiff/sampling.py
+++ b/foldingdiff/sampling.py
@@ -54,8 +54,8 @@ def p_sample(
 
     # Create the attention mask
     attn_mask = torch.zeros(x.shape[:2], device=x.device)
-    for i, l in enumerate(seq_lens):
-        attn_mask[i, :l] = 1.0
+    for i, length in enumerate(seq_lens):
+        attn_mask[i, :length] = 1.0
 
     # Equation 11 in the paper
     # Use our model (noise predictor) to predict the mean
@@ -140,6 +140,7 @@ def sample(
     batch_size: int = 512,
     feature_key: str = "angles",
     disable_pbar: bool = False,
+    trim_to_length: bool = True,  # Trim padding regions to reduce memory
 ) -> List[np.ndarray]:
     """
     Sample from the given model. Use the train_dset to generate noise to sample
@@ -157,6 +158,10 @@ def sample(
     # Process each batch
     if sweep_lengths is not None:
         sweep_min, sweep_max = sweep_lengths
+        if not sweep_min < sweep_max:
+            raise ValueError(
+                f"Minimum length {sweep_min} must be less than maximum {sweep_max}"
+            )
         logging.info(
             f"Sweeping from {sweep_min}-{sweep_max} with {n} examples at each length"
         )
@@ -177,6 +182,11 @@ def sample(
         noise = train_dset.sample_noise(
             torch.zeros((batch, train_dset.pad, model.n_inputs), dtype=torch.float32)
         )
+
+        # Trim things that are beyond the length of what we are generating
+        if trim_to_length:
+            noise = noise[:, : max(this_lengths), :]
+
         # Produces (timesteps, batch_size, seq_len, n_ft)
         sampled = p_sample_loop(
             model=model,
@@ -255,7 +265,7 @@ def sample_simple(
 
 
 def _score_angles(
-    reconst_angles:pd.DataFrame, truth_angles:pd.DataFrame, truth_coords_pdb: str
+    reconst_angles: pd.DataFrame, truth_angles: pd.DataFrame, truth_coords_pdb: str
 ) -> Tuple[float, float]:
     """
     Helper function to scores sets of angles
@@ -348,6 +358,7 @@ def get_reconstruction_error(
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
-    s = sample_simple("wukevin/foldingdiff_cath", n=1, sweep_lengths=(50, 55))
+    s = sample_simple("wukevin/foldingdiff_cath", n=1, sweep_lengths=(50, 51))
     for i, x in enumerate(s):
         print(x.shape)
+        print(x)
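
Usage note (editor's sketch, not part of the patch): the snippet below exercises the surfaces this diff touches. sample_simple, sweep_lengths, and the "wukevin/foldingdiff_cath" checkpoint name are taken from the __main__ block above; that sample_simple forwards sweep_lengths to sample(), and therefore hits the new validation, is an assumption based on that block.

    # Sketch only, assuming sample_simple() forwards sweep_lengths to sample()
    from foldingdiff.sampling import sample_simple

    # Valid sweep (min < max): passes the new guard and samples as before
    for x in sample_simple("wukevin/foldingdiff_cath", n=1, sweep_lengths=(50, 51)):
        print(x.shape)

    # Degenerate sweep (min == max): the new guard raises before any sampling runs
    try:
        sample_simple("wukevin/foldingdiff_cath", n=1, sweep_lengths=(50, 50))
    except ValueError as err:
        print(err)

    # bin/sample.py also gains a --nopsea flag that skips both
    # make_ss_cooccurrence_plot() calls shown in the first hunk.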