Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added threading and Colabfold-like plots generation #442

Open
wants to merge 64 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
896d586
added multithread to colabfold search
jflucier Apr 17, 2024
17e3b93
added multithread to colabfold search
jflucier Apr 17, 2024
573623d
added multithread to colabfold search
jflucier Apr 17, 2024
d256af3
trying to run
jflucier Apr 22, 2024
89c0164
tests
jflucier Apr 23, 2024
acb0dcc
tests
jflucier Apr 23, 2024
e8fa8d2
tests
jflucier Apr 23, 2024
93caf7e
test
jflucier Apr 23, 2024
dd6f9bc
test
jflucier Apr 23, 2024
bfb2cac
test
jflucier Apr 23, 2024
be45f53
test
jflucier Apr 23, 2024
f256b4b
test
jflucier Apr 23, 2024
ec7f765
test
jflucier Apr 23, 2024
e995a1a
test
jflucier Apr 23, 2024
5ba5334
test
jflucier Apr 23, 2024
00d6f66
test
jflucier Apr 23, 2024
4eaf57d
test
jflucier Apr 23, 2024
752aa79
test
jflucier Apr 23, 2024
6454eb3
test
jflucier Apr 23, 2024
866ca58
test
jflucier Apr 23, 2024
126cb22
test
jflucier Apr 23, 2024
875246b
test
jflucier Apr 23, 2024
2dea90f
test
jflucier Apr 23, 2024
25fc5a4
test
jflucier Apr 23, 2024
3ad997f
test
jflucier Apr 23, 2024
220e1ce
added logging
jflucier Apr 25, 2024
3c35fb4
added logging
jflucier Apr 25, 2024
2b2db7c
added logging
jflucier Apr 25, 2024
d377bc4
added QC plot script
jflucier Apr 25, 2024
14a329d
patch plot script
jflucier Apr 25, 2024
b1e4a84
patch plot script
jflucier Apr 25, 2024
b4d5a08
patch to save feature dict in pickle for post plots
jflucier Apr 25, 2024
6c4e80a
patch to save feature dict in pickle for post plots
jflucier Apr 26, 2024
9fc31ef
start
jflucier Apr 26, 2024
21d1c9b
plots generation scripts
jflucier Apr 29, 2024
ee207c4
change input param for coverage
jflucier Apr 29, 2024
1a410ec
patch
jflucier Apr 29, 2024
68d1e6b
patch
jflucier Apr 29, 2024
9812856
unrecon fonts
jflucier Apr 29, 2024
322f74c
patch for feature dict pickle output
jflucier May 1, 2024
30bd52b
gen pae json output
jflucier May 1, 2024
df5884a
gen pae json output
jflucier May 1, 2024
de95dda
gen pae json output
jflucier May 1, 2024
03d24f2
export json like colabfold
jflucier May 3, 2024
bdfda8f
debug hhsearch
jflucier May 8, 2024
6974cff
debug hhsearch
jflucier May 8, 2024
45ac61b
debug hhsearch
jflucier May 8, 2024
79464ce
debug hhsearch
jflucier May 8, 2024
cb359fd
debug hhsearch
jflucier May 8, 2024
c09e7a8
debug hhsearch
jflucier May 8, 2024
c549f8e
debug hhsearch
jflucier May 8, 2024
6717ee6
colabsearch integration
jflucier May 9, 2024
6abf48f
gen plots new version
jflucier May 13, 2024
60a0b8b
add table to pae plot
jflucier May 14, 2024
c4e1400
split json creation from plot script
jflucier May 14, 2024
9e9dd62
split json creation from plot script
jflucier May 14, 2024
ceb056a
split json creation from plot script
jflucier May 14, 2024
ca54002
split json creation from plot script
jflucier May 14, 2024
f96ecea
improve exception messages
max-l Jul 19, 2024
2dbf243
patch to display multimer seq limit in plots
jflucier Jul 23, 2024
22f1ebb
patch to display multimer seq limit in plots
jflucier Jul 23, 2024
ae77037
patch to display multimer seq limit in plots
jflucier Jul 23, 2024
7ff67eb
patch to display multimer seq limit in plots
jflucier Jul 23, 2024
724db4e
patch to display multimer seq limit in plots
jflucier Jul 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,834 changes: 917 additions & 917 deletions notebooks/OpenFold.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion openfold/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def model_config(
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
else:
raise ValueError("Invalid model name")
raise ValueError(f"Invalid model name {name}")

if long_sequence_inference:
assert(not train)
Expand Down
3 changes: 1 addition & 2 deletions openfold/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,7 @@ def read_msa(start, size):
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
if not os.path.exists(uniprot_msa_path):
chain_id = os.path.basename(os.path.normpath(alignment_dir))
raise ValueError(f"Missing 'uniprot_hits.sto' for {chain_id}. "
raise ValueError(f"Missing file {uniprot_msa_path} for {chain_id}. "
f"This is required for Multimer MSA pairing.")

with open(uniprot_msa_path, "r") as fp:
Expand All @@ -1235,7 +1235,6 @@ def process_fasta(self,
input_fasta_str = f.read()

input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)

all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
Expand Down
4 changes: 2 additions & 2 deletions openfold/data/tools/hhsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def query(self, a3m: str, output_dir: Optional[str] = None) -> str:
if retcode:
# Stderr is truncated to prevent proto size errors in Beam.
raise RuntimeError(
"HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
"HHSearch failed:\ncommand:\n%s\n\nstdout:\n%s\n\nstderr:\n%s\n"
% (f"hhsearch command: {' '.join(cmd)}", stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
)

with open(hhr_path) as f:
Expand Down
57 changes: 55 additions & 2 deletions run_pretrained_openfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def generate_feature_dict(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
fasta_path=tmp_fasta_path,
alignment_dir=alignment_dir
)
elif len(seqs) == 1:
tag = tags[0]
Expand Down Expand Up @@ -180,6 +181,45 @@ def main(args):

config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)

print("")
print("#### INPUT / OUTPUT ####")
print(f"fasta_dir: {args.fasta_dir}")
print(f"output_dir: {args.output_dir}")
print(f"output prediction filenames: {args.output_postfix}")
print(f"cif_output: {args.cif_output}")
print(f"save embedded outputs: {args.save_outputs}")

print("")
print("#### PRESETS ####")
print(f"skip_relaxation: {args.skip_relaxation}")
print(f"use_precomputed_alignments: {args.use_precomputed_alignments}")
print(f"use_single_seq_mode: {args.use_single_seq_mode}")
print(f"long_sequence_inference: {args.long_sequence_inference}")
print(f"Threads: {args.cpus}")
print(f"multimer_ri_gap: {args.multimer_ri_gap}")
print(f"subtract_plddt: {args.subtract_plddt}")

print("")
print("#### MODEL PARAMS ####")
print(f"Model: {args.config_preset}")
print(f"trace_model: {args.trace_model}")

print("")
print("#### DATABASE PARAMS ####")
print(f"template_mmcif_dir: {args.template_mmcif_dir}")
print(f"max_template_date: {args.max_template_date}")
print(f"max_templates: {config.data.predict.max_templates}")
print(f"release_dates_path: {args.release_dates_path}")
print(f"obsolete_pdbs_path: {args.obsolete_pdbs_path}")

print("")
print("#### GPU / AI PARAMS ####")
print(f"model_device: {args.model_device}")
print(f"openfold_checkpoint_path: {args.openfold_checkpoint_path}")
print(f"jax_param_path: {args.jax_param_path}")

print("")

if args.trace_model:
if not config.data.predict.fixed_size:
raise ValueError(
Expand Down Expand Up @@ -237,6 +277,7 @@ def main(args):
for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
# Gather input sequences
fasta_path = os.path.join(args.fasta_dir, fasta_file)
print(f"reading fasta: {fasta_path}")
with open(fasta_path, "r") as fp:
data = fp.read()

Expand All @@ -258,12 +299,15 @@ def main(args):
seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
feature_dicts = {}

logger.info(f"Loading model information...")
model_generator = load_models_from_command_line(
config,
args.model_device,
args.openfold_checkpoint_path,
args.jax_param_path,
args.output_dir)
args.output_dir
)

for model, output_directory in model_generator:
cur_tracing_interval = 0
Expand All @@ -273,6 +317,7 @@ def main(args):
output_name = f'{output_name}_{args.output_postfix}'

# Does nothing if the alignments have already been computed
logger.info(f"Perform alignment if not already done...")
precompute_alignments(tags, seqs, alignment_dir, args)

feature_dict = feature_dicts.get(tag, None)
Expand All @@ -298,6 +343,10 @@ def main(args):
feature_dict, mode='predict', is_multimer=is_multimer
)

# print("Storing feature dict...")
# with open(os.path.join(args.output_dir, f"{output_name}_feature_dict.pickle"), "wb") as fp:
# pickle.dump(processed_feature_dict, fp, protocol=pickle.HIGHEST_PROTOCOL)

processed_feature_dict = {
k: torch.as_tensor(v, device=args.model_device)
for k, v in processed_feature_dict.items()
Expand All @@ -316,6 +365,10 @@ def main(args):
)
cur_tracing_interval = rounded_seqlen

print("Storing feature dict...")
with open(os.path.join(args.output_dir, f"{output_name}_feature_dict.pickle"), "wb") as fp:
pickle.dump(feature_dict, fp, protocol=pickle.HIGHEST_PROTOCOL)

out = run_model(model, processed_feature_dict, tag, args.output_dir)

# Toss out the recycling dimensions --- we don't need them anymore
Expand Down
Loading