Speed up sampling by truncating input padding regions.
wukevin committed Sep 17, 2023
1 parent b59a937 commit 8e784ad
Showing 2 changed files with 26 additions and 13 deletions.
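Why this speeds things up: noise for sampling was always drawn at the dataset's full padded width, so every reverse-diffusion step ran the model over padding positions that were discarded afterwards. Truncating the noise tensor to the longest sequence actually being generated in each batch shrinks every model call. A minimal standalone sketch of the idea (the shapes and lengths below are illustrative, not the repository's values):

    import torch

    batch, pad_len, n_ft = 4, 128, 6   # hypothetical: dataset pads to 128 positions
    this_lengths = [50, 62, 71, 55]    # lengths requested for this batch

    # Before this commit: noise spans the full padded width.
    noise = torch.randn(batch, pad_len, n_ft)

    # The commit's change: drop columns beyond the batch's longest sequence.
    noise = noise[:, : max(this_lengths), :]
    print(noise.shape)  # torch.Size([4, 71, 6])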
20 changes: 11 additions & 9 deletions bin/sample.py
@@ -281,6 +281,7 @@ def build_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--testcomparison", action="store_true", help="Run comparison against test set"
     )
+    parser.add_argument("--nopsea", action="store_true", help="Skip PSEA calculations")
     parser.add_argument("--seed", type=int, default=SEED, help="Random seed")
     parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")
     return parser
@@ -453,18 +454,19 @@ def main() -> None:
     )
 
     # Generate plots of secondary structure co-occurrence
-    make_ss_cooccurrence_plot(
-        pdb_files,
-        str(outdir / "plots" / "ss_cooccurrence_sampled.pdf"),
-        threads=multiprocessing.cpu_count(),
-    )
-    if args.testcomparison:
+    if not args.nopsea:
         make_ss_cooccurrence_plot(
-            test_dset.filenames,
-            str(outdir / "plots" / "ss_cooccurrence_test.pdf"),
-            max_seq_len=test_dset.dset.pad,
+            pdb_files,
+            str(outdir / "plots" / "ss_cooccurrence_sampled.pdf"),
             threads=multiprocessing.cpu_count(),
         )
+        if args.testcomparison:
+            make_ss_cooccurrence_plot(
+                test_dset.filenames,
+                str(outdir / "plots" / "ss_cooccurrence_test.pdf"),
+                max_seq_len=test_dset.dset.pad,
+                threads=multiprocessing.cpu_count(),
+            )
 
 
 if __name__ == "__main__":
19 changes: 15 additions & 4 deletions foldingdiff/sampling.py
@@ -54,8 +54,8 @@ def p_sample(

     # Create the attention mask
     attn_mask = torch.zeros(x.shape[:2], device=x.device)
-    for i, l in enumerate(seq_lens):
-        attn_mask[i, :l] = 1.0
+    for i, length in enumerate(seq_lens):
+        attn_mask[i, :length] = 1.0
 
     # Equation 11 in the paper
     # Use our model (noise predictor) to predict the mean
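For context on the hunk above: attn_mask is a binary key-padding mask, ones over real residues and zeros over padding, so the rename from l to length is purely cosmetic. A self-contained illustration with made-up shapes:

    import torch

    x = torch.zeros(3, 6, 6)  # (batch, padded_seq_len, n_ft), illustrative only
    seq_lens = [4, 6, 2]

    attn_mask = torch.zeros(x.shape[:2], device=x.device)
    for i, length in enumerate(seq_lens):
        attn_mask[i, :length] = 1.0

    print(attn_mask)
    # tensor([[1., 1., 1., 1., 0., 0.],
    #         [1., 1., 1., 1., 1., 1.],
    #         [1., 1., 0., 0., 0., 0.]])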
@@ -140,6 +140,7 @@ def sample(
     batch_size: int = 512,
     feature_key: str = "angles",
     disable_pbar: bool = False,
+    trim_to_length: bool = True,  # Trim padding regions to reduce memory
 ) -> List[np.ndarray]:
     """
     Sample from the given model. Use the train_dset to generate noise to sample
@@ -157,6 +158,10 @@
     # Process each batch
     if sweep_lengths is not None:
         sweep_min, sweep_max = sweep_lengths
+        if not sweep_min < sweep_max:
+            raise ValueError(
+                f"Minimum length {sweep_min} must be less than maximum {sweep_max}"
+            )
         logging.info(
             f"Sweeping from {sweep_min}-{sweep_max} with {n} examples at each length"
         )
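The new guard turns what would presumably be an empty sweep (and thus zero generated samples) into an immediate error. A sketch of just the check, standalone and runnable (check_sweep is a hypothetical helper, not part of the repository):

    def check_sweep(sweep_lengths):
        # Mirrors the validation added above.
        sweep_min, sweep_max = sweep_lengths
        if not sweep_min < sweep_max:
            raise ValueError(
                f"Minimum length {sweep_min} must be less than maximum {sweep_max}"
            )

    check_sweep((50, 128))  # passes silently
    try:
        check_sweep((55, 50))
    except ValueError as e:
        print(e)  # Minimum length 55 must be less than maximum 50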
@@ -177,6 +182,11 @@
         noise = train_dset.sample_noise(
             torch.zeros((batch, train_dset.pad, model.n_inputs), dtype=torch.float32)
         )
+
+        # Trim things that are beyond the length of what we are generating
+        if trim_to_length:
+            noise = noise[:, : max(this_lengths), :]
+
         # Produces (timesteps, batch_size, seq_len, n_ft)
         sampled = p_sample_loop(
             model=model,
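The payoff of the trim above: transformer self-attention scales roughly quadratically with sequence length, so cutting a batch from the padded width down to its true maximum length makes each denoising step substantially cheaper. A back-of-envelope estimate with hypothetical numbers:

    pad_len, max_len = 128, 71  # illustrative: padded width vs. longest real sequence
    ratio = (max_len / pad_len) ** 2
    print(f"Approximate attention cost per step: {ratio:.0%} of untrimmed")
    # Approximate attention cost per step: 31% of untrimmed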
@@ -255,7 +265,7 @@ def sample_simple(


 def _score_angles(
-    reconst_angles:pd.DataFrame, truth_angles:pd.DataFrame, truth_coords_pdb: str
+    reconst_angles: pd.DataFrame, truth_angles: pd.DataFrame, truth_coords_pdb: str
 ) -> Tuple[float, float]:
     """
     Helper function to score sets of angles
Expand Down Expand Up @@ -348,6 +358,7 @@ def get_reconstruction_error(

 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
-    s = sample_simple("wukevin/foldingdiff_cath", n=1, sweep_lengths=(50, 55))
+    s = sample_simple("wukevin/foldingdiff_cath", n=1, sweep_lengths=(50, 51))
     for i, x in enumerate(s):
         print(x.shape)
         print(x)
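On the smoke-test change above: assuming sweep_lengths behaves like Python's range (minimum inclusive, maximum exclusive, with n samples per length), the old (50, 55) sweep generated structures at five lengths where the __main__ sanity check only needs one:

    # Assumed half-open semantics, matching range(sweep_min, sweep_max):
    print(list(range(50, 55)))  # [50, 51, 52, 53, 54]
    print(list(range(50, 51)))  # [50]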
