-
Notifications
You must be signed in to change notification settings - Fork 5
/
run_af2_step4.py
126 lines (104 loc) · 4 KB
/
run_af2_step4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python
#-*- coding:utf-8 -*-
###############################################
#
#
# Run AlphaFold2 step by step
# (https://github.com/deepmind/alphafold)
# Author: Pan Li ([email protected])
# @ Shuimu BioScience
# https://www.shuimubio.com/
#
#
################################################
#
#
# AlphaFold2 Step 4 -- Sort models
# Usage: run_af2_step4.py result_1.pkl,...,result_5.pkl pdb_1.pdb,...,pdb_5.pdb /path/to/output
#
#
import json
import os
import pathlib
import pickle
import random
import shutil
import sys
import time
import gzip
from typing import Dict, Union, Optional
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
import argparse
# Command-line interface: three positional string arguments.
# Each entry is (argument name, help text); the name is also used as metavar.
_POSITIONAL_ARGS = (
    (
        'model_pkl_list',
        'The result_model.pkl files generated by AlphaFold2 step 2, seperated by comma.',
    ),
    (
        'pdb_file_list',
        'The PDB files generated by AlphaFold2 step 2 or 3, seperated by comma.',
    ),
    (
        'output_dir',
        'Path to a directory that will store the results.',
    ),
)

parser = argparse.ArgumentParser(description='AlphaFold2 Step 4 -- Sort models')
for _arg_name, _arg_help in _POSITIONAL_ARGS:
    parser.add_argument(_arg_name, metavar=_arg_name, type=str, help=_arg_help)

# Parsing happens at module level; the result is read by the code below.
args = parser.parse_args()
######################
## Check inputs
######################

# Both list arguments are comma-separated strings. Entry i of one list is
# paired with entry i of the other, so the two lengths must agree.
model_pkl_fn_list = args.model_pkl_list.split(',')
pdb_fn_list = args.pdb_file_list.split(',')

# Bad arguments are reported through argparse instead of `assert`:
# assertions are removed under `python -O`, which would skip these checks.
if len(model_pkl_fn_list) != len(pdb_fn_list):
    parser.error("model_pkl_list and pdb_file_list must be the same length")

# The output directory is not created by this script; it has to exist already
# and has to be a directory (a regular file at that path is rejected too).
output_dir = args.output_dir
if not os.path.isdir(output_dir):
    parser.error(f"output_dir does not exist or is not a directory: {output_dir}")
######################
## Read inputs
######################

# One summary dict per model, in the same order as the command-line lists.
# Only the fields used further down are kept from each prediction result.
model_pkl_list = []
for model_pkl_fn, model_pdb_fn in zip(model_pkl_fn_list, pdb_fn_list):
    # Validate the PDB name before doing any expensive loading. An explicit
    # check is used because `assert` is removed under `python -O`.
    if not model_pdb_fn.endswith('.pdb'):
        parser.error(f"PDB file name must end with .pdb: {model_pdb_fn}")

    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only feed this script result files from a trusted source.
    # Gzip-compressed pickles are recognised by their '.gz' suffix.
    opener = gzip.open if model_pkl_fn.endswith('.gz') else open
    with opener(model_pkl_fn, 'rb') as pkl_file:
        prediction_result = pickle.load(pkl_file)

    # The model name is the PDB base name from the first 'model_' up to (but
    # not including) the '.pdb' suffix.
    pdb_stem = os.path.basename(model_pdb_fn)[:-len('.pdb')]
    name_start = pdb_stem.find('model_')
    # find() returns -1 when 'model_' is absent; slicing from -1 would give an
    # empty name (and every such model would collide), so use the whole stem.
    model_name = pdb_stem[name_start:] if name_start >= 0 else pdb_stem

    model_pkl_list.append({
        'ranking_confidence': prediction_result['ranking_confidence'],
        'plddt': prediction_result['plddt'],
        # Optional field: stored as None when the result does not contain it.
        'predicted_aligned_error': prediction_result.get('predicted_aligned_error', None),
        'model_name': model_name,
    })
    # Drop the full result before loading the next one to limit peak memory.
    del prediction_result
######################
## Sort
######################

# Indices into model_pkl_list, ordered from the highest ranking_confidence
# to the lowest (ascending argsort, then reversed).
_confidence_per_model = [entry['ranking_confidence'] for entry in model_pkl_list]
_ascending_order = np.argsort(_confidence_per_model)
dec_order = _ascending_order[::-1]
######################
## Rename PDB file and log file
######################

# Copy every PDB into output_dir under its 1-based rank:
# ranked_1.pdb is the model with the highest ranking_confidence.
for rank, model_idx in enumerate(dec_order, start=1):
    shutil.copyfile(
        pdb_fn_list[model_idx],
        os.path.join(output_dir, f'ranked_{rank}.pdb'),
    )

log = {
    # NOTE: despite the "plddt" key, the stored value is each model's
    # ranking_confidence. float() turns NumPy scalars into plain Python
    # floats, which json can always serialize (e.g. np.float32 cannot be
    # written directly).
    "plddt": {
        entry['model_name']: float(entry['ranking_confidence'])
        for entry in model_pkl_list
    },
    # Model names from best to worst.
    "order": [model_pkl_list[model_idx]['model_name'] for model_idx in dec_order],
}
with open(os.path.join(output_dir, 'ranking_debug.json'), 'w') as OUT:
    json.dump(log, OUT, indent=4)
######################
## Plot for visualization
######################

# For every model (visited best-ranked first) write a pLDDT bar chart, plus a
# predicted-aligned-error heatmap when the model entry has that matrix.
# Files are named after the model; the rank only appears in the plot title.
for rank, idx in enumerate(dec_order):
    model_entry = model_pkl_list[idx]
    model_name = model_entry['model_name']
    plot_title = f"{model_name} rank={rank+1}"

    pae = model_entry['predicted_aligned_error']
    if pae is not None:
        plt.figure(figsize=(8, 6))
        # Colour scale is fixed to [0, 20] so heatmaps of different models
        # are directly comparable.
        sns.heatmap(pae, vmin=0.0, vmax=20.0)
        # Both axes index residues ("Residual" in the labels was a typo).
        plt.xlabel("Residue")
        plt.ylabel("Residue")
        plt.title(plot_title)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{model_name}_predicted_aligned_error.png"))
        # Close the figure so figures do not accumulate across models.
        plt.close()

    plddt = model_entry['plddt']
    plt.figure(figsize=(12, 4))
    plt.bar(range(len(plddt)), plddt, linewidth=0.8)
    plt.xlabel("Residue")
    plt.ylabel("pLDDT")
    plt.title(plot_title)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"{model_name}_plddt.png"))
    plt.close()